mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Automated g4 rollback of changelist 162423171
PiperOrigin-RevId: 162437318
This commit is contained in:
parent
9293c557bd
commit
491beb74cc
|
|
@ -405,7 +405,6 @@ tf_cuda_library(
|
||||||
"util/tensor_slice_reader_cache.h",
|
"util/tensor_slice_reader_cache.h",
|
||||||
"util/tensor_slice_writer.h",
|
"util/tensor_slice_writer.h",
|
||||||
"util/use_cudnn.h",
|
"util/use_cudnn.h",
|
||||||
"util/matmul_autotune.h",
|
|
||||||
"util/util.h",
|
"util/util.h",
|
||||||
"util/work_sharder.h",
|
"util/work_sharder.h",
|
||||||
] + select({
|
] + select({
|
||||||
|
|
|
||||||
|
|
@ -157,7 +157,6 @@ cc_library(
|
||||||
hdrs = ["conv_2d.h"],
|
hdrs = ["conv_2d.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":eigen_helpers",
|
":eigen_helpers",
|
||||||
":gpu_util_hdrs",
|
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
|
|
@ -249,15 +248,6 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "gpu_util_hdrs",
|
|
||||||
hdrs = ["gpu_utils.h"],
|
|
||||||
deps = [
|
|
||||||
":eigen_helpers",
|
|
||||||
"//third_party/eigen3",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "ops_util_test",
|
name = "ops_util_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
|
@ -2416,9 +2406,7 @@ tf_kernel_library(
|
||||||
],
|
],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
}),
|
}),
|
||||||
deps = MATH_DEPS + [
|
deps = MATH_DEPS + select({
|
||||||
":gpu_util_hdrs",
|
|
||||||
] + select({
|
|
||||||
":xsmm": [
|
":xsmm": [
|
||||||
"@libxsmm_archive//:xsmm_avx",
|
"@libxsmm_archive//:xsmm_avx",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -21,12 +21,27 @@ limitations under the License.
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/kernels/gpu_utils.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/stream_executor.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// TODO(zhengxq): move this to gpu_util.h. The use of such wrappers is wide
|
||||||
|
// spread.
|
||||||
|
template <typename T>
|
||||||
|
inline perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
|
||||||
|
uint64 size) {
|
||||||
|
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
|
||||||
|
size * sizeof(T));
|
||||||
|
perftools::gputools::DeviceMemory<T> typed(wrapped);
|
||||||
|
return typed;
|
||||||
|
}
|
||||||
|
|
||||||
// Get the Cudnn workspace limit from the environment variable, which is in MB.
|
// Get the Cudnn workspace limit from the environment variable, which is in MB.
|
||||||
// Return the workspace memory limit in bytes. If no value is set, return the
|
// Return the workspace memory limit in bytes. If no value is set, return the
|
||||||
// default value.
|
// default value.
|
||||||
|
|
@ -41,10 +56,12 @@ class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
|
||||||
virtual ~CudnnScratchAllocator() {}
|
virtual ~CudnnScratchAllocator() {}
|
||||||
CudnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
|
CudnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
|
||||||
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
|
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
|
||||||
int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
|
virtual int64 GetMemoryLimitInBytes(
|
||||||
|
perftools::gputools::Stream* stream) override {
|
||||||
return memory_limit_;
|
return memory_limit_;
|
||||||
}
|
}
|
||||||
perftools::gputools::port::StatusOr<perftools::gputools::DeviceMemory<uint8>>
|
virtual perftools::gputools::port::StatusOr<
|
||||||
|
perftools::gputools::DeviceMemory<uint8>>
|
||||||
AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override {
|
AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override {
|
||||||
Tensor temporary_memory;
|
Tensor temporary_memory;
|
||||||
if (byte_size > memory_limit_) {
|
if (byte_size > memory_limit_) {
|
||||||
|
|
@ -168,6 +185,112 @@ class ConvParameters {
|
||||||
|
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
// A helper class that looks up the best autotuned config from parameters.
|
||||||
|
// Due to the noisy nature of autotune, especially with multiple devices, it
|
||||||
|
// only accepts a config if its margin exceeds a threshold.
|
||||||
|
// For the same shape configs, if a new best config matches the previous best,
|
||||||
|
// they get promoted; otherwise, the winner gets demoted. This process stops
|
||||||
|
// when the winner's score exceeds the threshold.
|
||||||
|
// In a bad case when two configs are very close to each other and flips
|
||||||
|
// back and forth randomly, the expected number of experiments before autotune
|
||||||
|
// settles is O(threshold ^ 2). So we recommend that number of warmup runs
|
||||||
|
// for any benchmarks.
|
||||||
|
template <typename Parameters, typename Config>
|
||||||
|
class AutoTuneMap {
|
||||||
|
public:
|
||||||
|
bool Find(const Parameters& params, Config* config) const {
|
||||||
|
mutex_lock lock(mu_);
|
||||||
|
auto iter = params_config_map_.find(params);
|
||||||
|
if (iter == params_config_map_.end() ||
|
||||||
|
iter->second.score < min_score_threshold_) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*config = iter->second.config;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
void Insert(const ConvParameters& params, const Config& config) {
|
||||||
|
mutex_lock lock(mu_);
|
||||||
|
auto iter = params_config_map_.find(params);
|
||||||
|
int new_score = 0;
|
||||||
|
if (iter == params_config_map_.end()) {
|
||||||
|
// Create a new entry if params is new.
|
||||||
|
VLOG(1) << GetActionSummary("creates", params, config);
|
||||||
|
params_config_map_.insert(std::make_pair(params, ValueType{config, 1}));
|
||||||
|
new_score = 1;
|
||||||
|
} else if (iter->second.score < min_score_threshold_) {
|
||||||
|
DCHECK(iter->second.score > 0);
|
||||||
|
if (iter->second.config != config) {
|
||||||
|
// If it is different from the current winner, demotes the winner.
|
||||||
|
VLOG(1) << GetActionSummary("demotes", params, config);
|
||||||
|
new_score = --iter->second.score;
|
||||||
|
if (new_score <= 0) {
|
||||||
|
VLOG(1) << GetActionSummary("erases", params, config);
|
||||||
|
params_config_map_.erase(iter);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If it is the same as the current winner, promotes the winner.
|
||||||
|
VLOG(1) << GetActionSummary("promotes", params, config);
|
||||||
|
new_score = ++iter->second.score;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (new_score >= min_score_threshold_) {
|
||||||
|
VLOG(1) << GetActionSummary("accepts", params, config);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
AutoTuneMap(const string& name) : name_(name) {
|
||||||
|
min_score_threshold_ = 1;
|
||||||
|
const char* threshold_str = getenv("TF_AUTOTUNE_THRESHOLD");
|
||||||
|
if (threshold_str != nullptr) {
|
||||||
|
strings::safe_strto32(threshold_str, &min_score_threshold_);
|
||||||
|
}
|
||||||
|
min_score_threshold_ = std::max(min_score_threshold_, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Group, class Params, class Cfg>
|
||||||
|
friend class AutoTuneSingleton;
|
||||||
|
|
||||||
|
struct Hasher {
|
||||||
|
std::size_t operator()(const Parameters& parameter) const {
|
||||||
|
return parameter.hash();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
string GetActionSummary(StringPiece action, const Parameters& params,
|
||||||
|
const Config& config) {
|
||||||
|
return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(),
|
||||||
|
action.ToString().c_str(), params.ToString().c_str(),
|
||||||
|
config.ToString().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
mutable mutex mu_;
|
||||||
|
struct ValueType {
|
||||||
|
Config config;
|
||||||
|
int32 score;
|
||||||
|
};
|
||||||
|
std::unordered_map<Parameters, ValueType, Hasher> params_config_map_
|
||||||
|
GUARDED_BY(mu_);
|
||||||
|
string name_;
|
||||||
|
int32 min_score_threshold_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(AutoTuneMap);
|
||||||
|
};
|
||||||
|
|
||||||
|
// A Singleton helper that manages the global autotune results by groups.
|
||||||
|
// The caller specified arbitrary Group type that can distinguish between
|
||||||
|
// different autotune results, even if their Parameters and Configs are the
|
||||||
|
// same.
|
||||||
|
template <class Group, typename Parameters, typename Config>
|
||||||
|
class AutoTuneSingleton {
|
||||||
|
public:
|
||||||
|
typedef AutoTuneMap<Parameters, Config> AutoTuneType;
|
||||||
|
static AutoTuneType* GetInstance() {
|
||||||
|
static AutoTuneType* instance = new AutoTuneType(Group::name());
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
|
||||||
|
|
@ -1,166 +0,0 @@
|
||||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
|
|
||||||
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
|
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
|
||||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
|
||||||
#include "tensorflow/core/platform/stream_executor.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
|
|
||||||
uint64 size) {
|
|
||||||
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
|
|
||||||
size * sizeof(T));
|
|
||||||
perftools::gputools::DeviceMemory<T> typed(wrapped);
|
|
||||||
return typed;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// A helper class that looks up the best autotuned config from parameters.
|
|
||||||
// Due to the noisy nature of autotune, especially with multiple devices, it
|
|
||||||
// only accepts a config if its margin exceeds a threshold.
|
|
||||||
// For the same shape configs, if a new best config matches the previous best,
|
|
||||||
// they get promoted; otherwise, the winner gets demoted. This process stops
|
|
||||||
// when the winner's score exceeds the threshold.
|
|
||||||
// In a bad case when two configs are very close to each other and flips
|
|
||||||
// back and forth randomly, the expected number of experiments before autotune
|
|
||||||
// settles is O(threshold ^ 2). So we recommend that number of warmup runs
|
|
||||||
// for any benchmarks.
|
|
||||||
template <typename Parameters, typename Config>
|
|
||||||
class AutoTuneMap {
|
|
||||||
public:
|
|
||||||
bool Find(const Parameters& params, Config* config) const {
|
|
||||||
mutex_lock lock(mu_);
|
|
||||||
auto iter = params_config_map_.find(params);
|
|
||||||
if (iter == params_config_map_.end() ||
|
|
||||||
(iter->second.score < min_score_threshold_ &&
|
|
||||||
iter->second.count <= max_autotune_count_)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
*config = iter->second.config;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
void Insert(const Parameters& params, const Config& config) {
|
|
||||||
mutex_lock lock(mu_);
|
|
||||||
auto iter = params_config_map_.find(params);
|
|
||||||
int new_score = 0;
|
|
||||||
if (iter == params_config_map_.end()) {
|
|
||||||
// Create a new entry if params is new.
|
|
||||||
VLOG(1) << GetActionSummary("creates", params, config);
|
|
||||||
params_config_map_.insert(
|
|
||||||
std::make_pair(params, ValueType{config, 1, 1}));
|
|
||||||
new_score = 1;
|
|
||||||
} else if (iter->second.score < min_score_threshold_ &&
|
|
||||||
iter->second.count <= max_autotune_count_) {
|
|
||||||
DCHECK_GT(iter->second.score, 0);
|
|
||||||
if (iter->second.config != config) {
|
|
||||||
// If it is different from the current winner, demotes the winner.
|
|
||||||
VLOG(1) << GetActionSummary("demotes", params, config);
|
|
||||||
new_score = --iter->second.score;
|
|
||||||
++iter->second.count;
|
|
||||||
if (new_score <= 0) {
|
|
||||||
VLOG(1) << GetActionSummary("erases", params, config);
|
|
||||||
params_config_map_.erase(iter);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// If it is the same as the current winner, promotes the winner.
|
|
||||||
VLOG(1) << GetActionSummary("promotes", params, config);
|
|
||||||
new_score = ++iter->second.score;
|
|
||||||
++iter->second.count;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (new_score >= min_score_threshold_) {
|
|
||||||
VLOG(1) << GetActionSummary("accepts", params, config);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
AutoTuneMap(const string& name) : name_(name) {
|
|
||||||
min_score_threshold_ = 1;
|
|
||||||
int min_warmup_iterations = 10;
|
|
||||||
const char* threshold_str = getenv("TF_AUTOTUNE_THRESHOLD");
|
|
||||||
if (threshold_str != nullptr) {
|
|
||||||
strings::safe_strto32(threshold_str, &min_score_threshold_);
|
|
||||||
}
|
|
||||||
const char* min_warmup_iteration_str =
|
|
||||||
getenv("TF_AUTOTUNE_MIN_WARMUP_ITERATIONS");
|
|
||||||
if (min_warmup_iteration_str != nullptr) {
|
|
||||||
strings::safe_strto32(min_warmup_iteration_str, &min_warmup_iterations);
|
|
||||||
}
|
|
||||||
min_score_threshold_ = std::max(min_score_threshold_, 1);
|
|
||||||
max_autotune_count_ = std::max(
|
|
||||||
5 * min_score_threshold_ * min_score_threshold_, min_warmup_iterations);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class Group, class Params, class Cfg>
|
|
||||||
friend class AutoTuneSingleton;
|
|
||||||
|
|
||||||
struct Hasher {
|
|
||||||
std::size_t operator()(const Parameters& parameter) const {
|
|
||||||
return parameter.hash();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
string GetActionSummary(StringPiece action, const Parameters& params,
|
|
||||||
const Config& config) {
|
|
||||||
return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(),
|
|
||||||
action.ToString().c_str(), params.ToString().c_str(),
|
|
||||||
config.ToString().c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
mutable mutex mu_;
|
|
||||||
struct ValueType {
|
|
||||||
Config config;
|
|
||||||
int32 score;
|
|
||||||
int32 count;
|
|
||||||
};
|
|
||||||
std::unordered_map<Parameters, ValueType, Hasher> params_config_map_
|
|
||||||
GUARDED_BY(mu_);
|
|
||||||
string name_;
|
|
||||||
int32 min_score_threshold_;
|
|
||||||
int32 max_autotune_count_;
|
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(AutoTuneMap);
|
|
||||||
};
|
|
||||||
|
|
||||||
// A Singleton helper that manages the global autotune results by groups.
|
|
||||||
// The caller specified arbitrary Group type that can distinguish between
|
|
||||||
// different autotune results, even if their Parameters and Configs are the
|
|
||||||
// same.
|
|
||||||
template <class Group, typename Parameters, typename Config>
|
|
||||||
class AutoTuneSingleton {
|
|
||||||
public:
|
|
||||||
typedef AutoTuneMap<Parameters, Config> AutoTuneType;
|
|
||||||
static AutoTuneType* GetInstance() {
|
|
||||||
static AutoTuneType* instance = new AutoTuneType(Group::name());
|
|
||||||
return instance;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
|
||||||
|
|
||||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
|
|
||||||
|
|
@ -23,15 +23,27 @@ limitations under the License.
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/kernels/fill_functor.h"
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
#include "tensorflow/core/util/matmul_autotune.h"
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "cuda/include/cuda.h"
|
#include "cuda/include/cuda.h"
|
||||||
#include "tensorflow/core/kernels/gpu_utils.h"
|
|
||||||
#include "tensorflow/core/platform/stream_executor.h"
|
#include "tensorflow/core/platform/stream_executor.h"
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename T>
|
||||||
|
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
|
||||||
|
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
|
||||||
|
perftools::gputools::DeviceMemory<T> typed(wrapped);
|
||||||
|
return typed;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
|
|
@ -111,16 +123,10 @@ bool ExplicitVectorMatrixOptimization<Eigen::half>(
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
struct LaunchMatMulBase {
|
struct LaunchMatMulBase {
|
||||||
#if GOOGLE_CUDA
|
|
||||||
typedef perftools::gputools::blas::AlgorithmType AlgorithmType;
|
|
||||||
#else
|
|
||||||
typedef int64 AlgorithmType;
|
|
||||||
#endif // GOOGLE_CUDA
|
|
||||||
|
|
||||||
static void launch(
|
static void launch(
|
||||||
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
|
OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
|
||||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
|
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
|
||||||
std::vector<AlgorithmType>* algorithms, bool use_aututone, Tensor* out) {
|
Tensor* out) {
|
||||||
#ifndef TENSORFLOW_USE_SYCL
|
#ifndef TENSORFLOW_USE_SYCL
|
||||||
// An explicit vector-matrix multiply is much better optimized than an
|
// An explicit vector-matrix multiply is much better optimized than an
|
||||||
// implicit one and this is a bottleneck during non-batched inference.
|
// implicit one and this is a bottleneck during non-batched inference.
|
||||||
|
|
@ -134,10 +140,6 @@ struct LaunchMatMulBase {
|
||||||
}
|
}
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
}
|
}
|
||||||
|
|
||||||
static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
|
|
||||||
std::vector<int64>* algorithms,
|
|
||||||
bool* algorithm_set_flag) {}
|
|
||||||
};
|
};
|
||||||
// On CPUs, we ignore USE_CUBLAS
|
// On CPUs, we ignore USE_CUBLAS
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
@ -157,39 +159,24 @@ struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct LaunchBlasGemv {
|
struct LaunchBlasGemv {
|
||||||
static void Compute(
|
static void Compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
|
||||||
OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
|
bool trans, uint64 m, uint64 n,
|
||||||
uint64 m, uint64 n, const perftools::gputools::DeviceMemory<T>& a,
|
const perftools::gputools::DeviceMemory<T>& a,
|
||||||
const perftools::gputools::DeviceMemory<T>& b,
|
const perftools::gputools::DeviceMemory<T>& b,
|
||||||
perftools::gputools::DeviceMemory<T>* c,
|
perftools::gputools::DeviceMemory<T>* c) {
|
||||||
perftools::gputools::blas::ProfileResult* output_profile) {
|
|
||||||
const auto blas_trans =
|
const auto blas_trans =
|
||||||
trans ? perftools::gputools::blas::Transpose::kTranspose
|
trans ? perftools::gputools::blas::Transpose::kTranspose
|
||||||
: perftools::gputools::blas::Transpose::kNoTranspose;
|
: perftools::gputools::blas::Transpose::kNoTranspose;
|
||||||
if (output_profile == nullptr) {
|
bool blas_launch_status =
|
||||||
bool blas_launch_status =
|
stream
|
||||||
stream
|
->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
|
||||||
->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
|
static_cast<T>(0.0), c, 1)
|
||||||
static_cast<T>(0.0), c, 1)
|
.ok();
|
||||||
.ok();
|
if (!blas_launch_status) {
|
||||||
if (!blas_launch_status) {
|
ctx->SetStatus(
|
||||||
ctx->SetStatus(
|
errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
|
||||||
errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
bool blas_launch_status =
|
|
||||||
stream
|
|
||||||
->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
|
|
||||||
a, m, b, 1, static_cast<T>(0.0), c, 1,
|
|
||||||
output_profile)
|
|
||||||
.ok();
|
|
||||||
if (!blas_launch_status) {
|
|
||||||
ctx->SetStatus(errors::Internal(
|
|
||||||
"Blas GEMV with profiling launch failed: m=", m, ", n=", n));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -201,8 +188,7 @@ void LaunchBlasGemv<Eigen::half>::Compute(
|
||||||
OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
|
OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
|
||||||
uint64 m, uint64 n, const perftools::gputools::DeviceMemory<Eigen::half>& a,
|
uint64 m, uint64 n, const perftools::gputools::DeviceMemory<Eigen::half>& a,
|
||||||
const perftools::gputools::DeviceMemory<Eigen::half>& b,
|
const perftools::gputools::DeviceMemory<Eigen::half>& b,
|
||||||
perftools::gputools::DeviceMemory<Eigen::half>* c,
|
perftools::gputools::DeviceMemory<Eigen::half>* c) {
|
||||||
perftools::gputools::blas::ProfileResult* output_profile) {
|
|
||||||
ctx->SetStatus(errors::Internal(
|
ctx->SetStatus(errors::Internal(
|
||||||
"Blas GEMV launch failed: GEMV is not implemented for float16."));
|
"Blas GEMV launch failed: GEMV is not implemented for float16."));
|
||||||
}
|
}
|
||||||
|
|
@ -214,55 +200,15 @@ bool LaunchBlasGemv<Eigen::half>::IsSupported() {
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
bool GetCublasAutotuneComputationType(
|
|
||||||
const DataType& dtype,
|
|
||||||
perftools::gputools::blas::ComputationType* compute_type) {
|
|
||||||
using perftools::gputools::blas::ComputationType;
|
|
||||||
bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
|
|
||||||
switch (dtype) {
|
|
||||||
case DT_HALF:
|
|
||||||
case DT_BFLOAT16:
|
|
||||||
if (use_f32_for_f16_computation) {
|
|
||||||
*compute_type = ComputationType::kF32;
|
|
||||||
} else {
|
|
||||||
*compute_type = ComputationType::kF16;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
case DT_FLOAT:
|
|
||||||
*compute_type = ComputationType::kF32;
|
|
||||||
return true;
|
|
||||||
case DT_DOUBLE:
|
|
||||||
*compute_type = ComputationType::kF64;
|
|
||||||
return true;
|
|
||||||
default:
|
|
||||||
// Unsupported compute_type, return false.
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A dummy type to group matmul autotune results together.
|
|
||||||
struct MatmulAutoTuneGroup {
|
|
||||||
static string name() { return "Matmul"; }
|
|
||||||
};
|
|
||||||
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
|
|
||||||
perftools::gputools::blas::AlgorithmConfig>
|
|
||||||
AutoTuneMatmul;
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
|
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
|
||||||
static void launch(
|
static void launch(
|
||||||
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
|
OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
|
||||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
|
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
|
||||||
std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
|
Tensor* out) {
|
||||||
using perftools::gputools::blas::AlgorithmConfig;
|
perftools::gputools::blas::Transpose trans[] = {
|
||||||
using perftools::gputools::blas::ComputationType;
|
perftools::gputools::blas::Transpose::kNoTranspose,
|
||||||
using perftools::gputools::blas::ProfileResult;
|
perftools::gputools::blas::Transpose::kTranspose};
|
||||||
using perftools::gputools::blas::Transpose;
|
|
||||||
using perftools::gputools::blas::kDefaultAlgorithm;
|
|
||||||
using perftools::gputools::blas::kDefaultBlasGemm;
|
|
||||||
using perftools::gputools::blas::kDefaultBlasGemv;
|
|
||||||
using perftools::gputools::blas::kNoAlgorithm;
|
|
||||||
Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
|
|
||||||
const uint64 m = a.dim_size(1 - dim_pair[0].first);
|
const uint64 m = a.dim_size(1 - dim_pair[0].first);
|
||||||
const uint64 k = a.dim_size(dim_pair[0].first);
|
const uint64 k = a.dim_size(dim_pair[0].first);
|
||||||
const uint64 n = b.dim_size(1 - dim_pair[0].second);
|
const uint64 n = b.dim_size(1 - dim_pair[0].second);
|
||||||
|
|
@ -274,155 +220,34 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
|
||||||
auto* stream = ctx->op_device_context()->stream();
|
auto* stream = ctx->op_device_context()->stream();
|
||||||
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
|
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
|
||||||
|
|
||||||
auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
|
auto a_ptr = AsDeviceMemory(a.template flat<T>().data());
|
||||||
a.template flat<T>().size());
|
auto b_ptr = AsDeviceMemory(b.template flat<T>().data());
|
||||||
auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
|
auto c_ptr = AsDeviceMemory(out->template flat<T>().data());
|
||||||
b.template flat<T>().size());
|
// Cublas does
|
||||||
auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
|
// C = A x B
|
||||||
out->template flat<T>().size());
|
// where A, B and C are assumed to be in column major.
|
||||||
auto alpha = static_cast<T>(1.0);
|
// We want the output to be in row-major, so we can compute
|
||||||
auto beta = static_cast<T>(0.0);
|
// C' = B' x A' (' stands for transpose)
|
||||||
|
if (LaunchBlasGemv<T>::IsSupported() && n == 1) {
|
||||||
int device_id = stream->parent()->device_ordinal();
|
// This is a matrix*vector multiply so use GEMV to compute A * b.
|
||||||
DataType dtype = a.dtype();
|
// Here we are multiplying in the natural order, so we have to flip
|
||||||
MatmulParameters matmul_parameters = {
|
// the transposition flag to compensate for the tensor being stored
|
||||||
transpose_a, transpose_b, m, n, k, dtype, device_id,
|
// row-major.
|
||||||
};
|
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a, transpose_a ? m : k,
|
||||||
AlgorithmConfig algorithm_config(kNoAlgorithm);
|
transpose_a ? k : m, a_ptr, b_ptr, &c_ptr);
|
||||||
|
} else {
|
||||||
ComputationType computation_type;
|
bool blas_launch_status =
|
||||||
bool compute_type_supported =
|
stream
|
||||||
GetCublasAutotuneComputationType(dtype, &computation_type);
|
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, 1.0f,
|
||||||
if (use_autotune && compute_type_supported && !algorithms->empty()) {
|
b_ptr, transpose_b ? k : n, a_ptr,
|
||||||
ProfileResult best_result;
|
transpose_a ? m : k, 0.0f, &c_ptr, n)
|
||||||
// TODO(yangzihao): Unify this code with conv autotuning.
|
.ok();
|
||||||
if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
|
if (!blas_launch_status) {
|
||||||
&algorithm_config)) {
|
ctx->SetStatus(errors::Internal(
|
||||||
ProfileResult profile_result;
|
"Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
|
||||||
for (auto profile_algorithm : (*algorithms)) {
|
a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
|
||||||
// Cublas does
|
"), m=", m, ", n=", n, ", k=", k));
|
||||||
// C = A x B
|
|
||||||
// where A, B and C are assumed to be in column major.
|
|
||||||
// We want the output to be in row-major, so we can compute
|
|
||||||
// C' = B' x A' (' stands for transpose)
|
|
||||||
bool cublas_launch_status =
|
|
||||||
stream
|
|
||||||
->ThenBlasGemmWithAlgorithm(
|
|
||||||
blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
|
|
||||||
transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
|
|
||||||
&c_ptr, n, computation_type, profile_algorithm,
|
|
||||||
&profile_result)
|
|
||||||
.ok();
|
|
||||||
if (cublas_launch_status) {
|
|
||||||
if (profile_result.is_valid()) {
|
|
||||||
if (profile_result.elapsed_time_in_ms() <
|
|
||||||
best_result.elapsed_time_in_ms()) {
|
|
||||||
best_result = profile_result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Try BlasGemmWithProfiling
|
|
||||||
bool cublas_launch_status =
|
|
||||||
stream
|
|
||||||
->ThenBlasGemmWithProfiling(
|
|
||||||
blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
|
|
||||||
transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
|
|
||||||
&c_ptr, n, &profile_result)
|
|
||||||
.ok();
|
|
||||||
if (cublas_launch_status) {
|
|
||||||
if (profile_result.is_valid()) {
|
|
||||||
if (profile_result.elapsed_time_in_ms() <
|
|
||||||
best_result.elapsed_time_in_ms()) {
|
|
||||||
best_result = profile_result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Try BlasGemvWithProfiling
|
|
||||||
if (LaunchBlasGemv<T>::IsSupported() && n == 1) {
|
|
||||||
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
|
|
||||||
transpose_a ? m : k, transpose_a ? k : m,
|
|
||||||
a_ptr, b_ptr, &c_ptr, &profile_result);
|
|
||||||
if (profile_result.is_valid()) {
|
|
||||||
if (profile_result.elapsed_time_in_ms() <
|
|
||||||
best_result.elapsed_time_in_ms()) {
|
|
||||||
best_result = profile_result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// We make sure that each matmul parameter set only gets one pass of
|
|
||||||
// autotune. If the best result is found, assign it to algorithm_type
|
|
||||||
// and insert it to autotune map. If all internal kernels of
|
|
||||||
// cublasGemmEx() returns invalid results, we add kNoAlgorithm to the
|
|
||||||
// autotune map.
|
|
||||||
if (best_result.is_valid()) {
|
|
||||||
algorithm_config.set_algorithm(best_result.algorithm());
|
|
||||||
}
|
|
||||||
AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
|
|
||||||
algorithm_config);
|
|
||||||
if (algorithm_config.algorithm() != kNoAlgorithm &&
|
|
||||||
algorithm_config.algorithm() != kDefaultBlasGemm &&
|
|
||||||
algorithm_config.algorithm() != kDefaultBlasGemv) {
|
|
||||||
bool cublas_launch_status =
|
|
||||||
stream
|
|
||||||
->ThenBlasGemmWithAlgorithm(
|
|
||||||
blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
|
|
||||||
transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
|
|
||||||
&c_ptr, n, computation_type, algorithm_config.algorithm(),
|
|
||||||
nullptr)
|
|
||||||
.ok();
|
|
||||||
if (!cublas_launch_status) {
|
|
||||||
ctx->SetStatus(errors::Internal(
|
|
||||||
"Blas GEMM with algorithm launch failed : a.shape=(",
|
|
||||||
a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
|
|
||||||
", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// For the following case, we use normal BlasGemm():
|
|
||||||
// 1) We didn't set the use_autotune flag;
|
|
||||||
// 2) compute type does not support autotune;
|
|
||||||
// 3) no algorithm is found;
|
|
||||||
// 4) all internal kernels in autotune return invalid results.
|
|
||||||
if (!use_autotune || !compute_type_supported || algorithms->empty() ||
|
|
||||||
algorithm_config.algorithm() == kNoAlgorithm ||
|
|
||||||
algorithm_config.algorithm() == kDefaultBlasGemm ||
|
|
||||||
algorithm_config.algorithm() == kDefaultBlasGemv) {
|
|
||||||
if (algorithm_config.algorithm() == kDefaultBlasGemv) {
|
|
||||||
// This is a matrix*vector multiply so use GEMV to compute A * b.
|
|
||||||
// Here we are multiplying in the natural order, so we have to flip
|
|
||||||
// the transposition flag to compensate for the tensor being stored
|
|
||||||
// row-major.
|
|
||||||
// TODO(yangzihao): Add Gemv as an autotuning option too.
|
|
||||||
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
|
|
||||||
transpose_a ? m : k, transpose_a ? k : m,
|
|
||||||
a_ptr, b_ptr, &c_ptr, nullptr);
|
|
||||||
} else {
|
|
||||||
// Use C' = B' x A' (' stands for transpose)
|
|
||||||
bool blas_launch_status =
|
|
||||||
stream
|
|
||||||
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
|
|
||||||
1.0f, b_ptr, transpose_b ? k : n, a_ptr,
|
|
||||||
transpose_a ? m : k, 0.0f, &c_ptr, n)
|
|
||||||
.ok();
|
|
||||||
if (!blas_launch_status) {
|
|
||||||
ctx->SetStatus(errors::Internal(
|
|
||||||
"Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
|
|
||||||
a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
|
|
||||||
"), m=", m, ", n=", n, ", k=", k));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
|
|
||||||
std::vector<int64>* algorithms,
|
|
||||||
bool* algorithm_set_flag) {
|
|
||||||
if (*algorithm_set_flag == false) {
|
|
||||||
auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
|
|
||||||
stream->parent()->GetBlasGemmAlgorithms(algorithms);
|
|
||||||
*algorithm_set_flag = true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -432,14 +257,9 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
|
||||||
template <typename Device, typename T, bool USE_CUBLAS>
|
template <typename Device, typename T, bool USE_CUBLAS>
|
||||||
class MatMulOp : public OpKernel {
|
class MatMulOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit MatMulOp(OpKernelConstruction* ctx)
|
explicit MatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
: OpKernel(ctx), algorithms_set_already_(false) {
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
|
||||||
|
|
||||||
LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
|
|
||||||
ctx, &algorithms_, &algorithms_set_already_);
|
|
||||||
use_autotune_ = MatmulAutotuneEnable();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
|
|
@ -482,14 +302,10 @@ class MatMulOp : public OpKernel {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
LaunchMatMul<Device, T, USE_CUBLAS>::launch(
|
LaunchMatMul<Device, T, USE_CUBLAS>::launch(ctx, this, a, b, dim_pair, out);
|
||||||
ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int64> algorithms_;
|
|
||||||
bool algorithms_set_already_;
|
|
||||||
bool use_autotune_;
|
|
||||||
bool transpose_a_;
|
bool transpose_a_;
|
||||||
bool transpose_b_;
|
bool transpose_b_;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,7 @@ limitations under the License.
|
||||||
#define TENSORFLOW_KERNELS_MATMUL_OP_H_
|
#define TENSORFLOW_KERNELS_MATMUL_OP_H_
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace functor {
|
namespace functor {
|
||||||
|
|
@ -52,68 +50,6 @@ struct MatMulFunctor {
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace functor
|
} // end namespace functor
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
|
||||||
// Encapsulate all the shape information that is used in matmul operations.
|
|
||||||
class MatmulParameters {
|
|
||||||
public:
|
|
||||||
MatmulParameters(bool transa, bool transb, uint64 m, uint64 n, uint64 k,
|
|
||||||
DataType dtype, int device_id)
|
|
||||||
: transa_(transa),
|
|
||||||
transb_(transb),
|
|
||||||
m_(m),
|
|
||||||
n_(n),
|
|
||||||
k_(k),
|
|
||||||
dtype_(dtype),
|
|
||||||
device_id_(device_id) {
|
|
||||||
hash_code_ = transa;
|
|
||||||
hash_code_ = Hash64Combine(hash_code_, transb);
|
|
||||||
hash_code_ = Hash64Combine(hash_code_, m);
|
|
||||||
hash_code_ = Hash64Combine(hash_code_, n);
|
|
||||||
hash_code_ = Hash64Combine(hash_code_, k);
|
|
||||||
hash_code_ = Hash64Combine(hash_code_, dtype);
|
|
||||||
hash_code_ = Hash64Combine(hash_code_, device_id);
|
|
||||||
}
|
|
||||||
bool operator==(const MatmulParameters& other) const {
|
|
||||||
return this->get_data_as_tuple() == other.get_data_as_tuple();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operator!=(const MatmulParameters& other) const {
|
|
||||||
return !(*this == other);
|
|
||||||
}
|
|
||||||
uint64 hash() const { return hash_code_; }
|
|
||||||
|
|
||||||
string ToString() const {
|
|
||||||
// clang-format off
|
|
||||||
return strings::StrCat(
|
|
||||||
transa_, ", ", transb_, ", ",
|
|
||||||
m_, ", ", n_, ", ", k_,
|
|
||||||
dtype_, ", ", device_id_);
|
|
||||||
// clang-format on
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
typedef std::tuple<bool, bool, int64, int64, int64, DataType, int>
|
|
||||||
ParameterDataType;
|
|
||||||
|
|
||||||
ParameterDataType get_data_as_tuple() const {
|
|
||||||
return std::make_tuple(transa_, transb_, m_, n_, k_, dtype_, device_id_);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool transa_;
|
|
||||||
bool transb_;
|
|
||||||
uint64 m_;
|
|
||||||
uint64 n_;
|
|
||||||
uint64 k_;
|
|
||||||
DataType dtype_;
|
|
||||||
int device_id_;
|
|
||||||
uint64 hash_code_;
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_KERNELS_MATMUL_OP_H_
|
#endif // TENSORFLOW_KERNELS_MATMUL_OP_H_
|
||||||
|
|
|
||||||
|
|
@ -1,51 +0,0 @@
|
||||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/core/util/matmul_autotune.h"
|
|
||||||
|
|
||||||
#include "tensorflow/core/framework/types.h"
|
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
|
||||||
#include "tensorflow/core/util/env_var.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
bool MatmulAutotuneEnable() {
|
|
||||||
bool value;
|
|
||||||
Status status =
|
|
||||||
ReadBoolFromEnvVar("TF_MATMUL_AUTOTUNE_ENABLE", false, &value);
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(ERROR) << status.error_message();
|
|
||||||
}
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool MatmulDoFP32ComputationFP16Input() {
|
|
||||||
bool value;
|
|
||||||
// Feedback from NVIDIA: the "true floating point 16" compute capability is
|
|
||||||
// absent from compute capability SM 5.2. The native 16 bit floating point
|
|
||||||
// computation was introduced in SM 5.3 and higher compute capability. So
|
|
||||||
// for compatibility, set this to be true by default for now.
|
|
||||||
// TODO(yangzihao): In the future, we need to return three possibilities:
|
|
||||||
// user-set-true, user-set-false, user-no-setting. In the calling sites,
|
|
||||||
// check the compatibilities. Note that user-set-false with compute
|
|
||||||
// capability <= 5.2 will cause an error in the later cublasGemmEx() call.
|
|
||||||
Status status =
|
|
||||||
ReadBoolFromEnvVar("TF_FP16_MATMUL_USE_FP32_COMPUTE", true, &value);
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(ERROR) << status.error_message();
|
|
||||||
}
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
// The utility to check matmul autotune related flags.
|
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
|
|
||||||
#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
|
|
||||||
bool MatmulAutotuneEnable();
|
|
||||||
bool MatmulDoFP32ComputationFP16Input();
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
|
|
||||||
|
|
@ -3850,48 +3850,6 @@ cuda_py_test(
|
||||||
main = "ops/transpose_benchmark.py",
|
main = "ops/transpose_benchmark.py",
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
|
||||||
name = "matmul_benchmark",
|
|
||||||
size = "medium",
|
|
||||||
srcs = ["ops/matmul_benchmark.py"],
|
|
||||||
additional_deps = [
|
|
||||||
":math_ops",
|
|
||||||
":random_ops",
|
|
||||||
":client",
|
|
||||||
":client_testlib",
|
|
||||||
":control_flow_ops",
|
|
||||||
":framework_for_generated_wrappers",
|
|
||||||
":framework_test_lib",
|
|
||||||
":platform",
|
|
||||||
":platform_benchmark",
|
|
||||||
":variables",
|
|
||||||
"//third_party/py/numpy",
|
|
||||||
"//tensorflow/core:protos_all_py",
|
|
||||||
],
|
|
||||||
main = "ops/matmul_benchmark.py",
|
|
||||||
)
|
|
||||||
|
|
||||||
cuda_py_test(
|
|
||||||
name = "matmul_benchmark_test",
|
|
||||||
size = "medium",
|
|
||||||
srcs = ["ops/matmul_benchmark_test.py"],
|
|
||||||
additional_deps = [
|
|
||||||
":math_ops",
|
|
||||||
":random_ops",
|
|
||||||
":client",
|
|
||||||
":client_testlib",
|
|
||||||
":control_flow_ops",
|
|
||||||
":framework_for_generated_wrappers",
|
|
||||||
":platform",
|
|
||||||
":platform_benchmark",
|
|
||||||
":matmul_benchmark",
|
|
||||||
":variables",
|
|
||||||
"//third_party/py/numpy",
|
|
||||||
"//tensorflow/core:protos_all_py",
|
|
||||||
],
|
|
||||||
main = "ops/matmul_benchmark_test.py",
|
|
||||||
)
|
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "session_benchmark",
|
name = "session_benchmark",
|
||||||
srcs = ["client/session_benchmark.py"],
|
srcs = ["client/session_benchmark.py"],
|
||||||
|
|
|
||||||
|
|
@ -31,9 +31,6 @@ from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test as test_lib
|
from tensorflow.python.platform import test as test_lib
|
||||||
|
|
||||||
# TODO(yangzihao): Currently matmul autotuning is disabled by default. Use
|
|
||||||
# os.environ["TF_MATMUL_AUTOTUNE_ENABLE"] = "1" to enable it.
|
|
||||||
|
|
||||||
|
|
||||||
def _AddTest(test, op_name, testcase_name, fn):
|
def _AddTest(test, op_name, testcase_name, fn):
|
||||||
test_name = "_".join(["test", op_name, testcase_name])
|
test_name = "_".join(["test", op_name, testcase_name])
|
||||||
|
|
|
||||||
|
|
@ -1,143 +0,0 @@
|
||||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Benchmark for Matmul operator."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import itertools
|
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from tensorflow.python.client import session as session_lib
|
|
||||||
from tensorflow.python.framework import ops
|
|
||||||
from tensorflow.python.ops import control_flow_ops
|
|
||||||
from tensorflow.python.ops import math_ops
|
|
||||||
from tensorflow.python.ops import random_ops
|
|
||||||
from tensorflow.python.ops import variables
|
|
||||||
from tensorflow.python.platform import test
|
|
||||||
|
|
||||||
|
|
||||||
def build_graph(device, n, m, k, transpose_a, transpose_b, dtype):
  """Creates the ops for one matmul between two randomly initialized variables.

  Args:
    device: String, the device type to place the ops on; index 0 is used.
    n: first dimension of tensor A (second dimension when transpose_a is set).
    m: second dimension of tensor A; also the dimension shared with tensor B.
    k: second dimension of tensor B (first dimension when transpose_b is set).
    transpose_a: boolean, whether tensor A is stored transposed.
    transpose_b: boolean, whether tensor B is stored transposed.
    dtype: numpy data type of the input tensors.

  Returns:
    A grouped operation wrapping the matmul, to be passed to run().
  """
  # A transposed operand is stored with its two dimensions swapped.
  lhs_shape = [m, n] if transpose_a else [n, m]
  rhs_shape = [k, m] if transpose_b else [m, k]

  with ops.device('/%s:0' % device):
    x = variables.Variable(random_ops.random_uniform(lhs_shape, dtype=dtype))
    y = variables.Variable(random_ops.random_uniform(rhs_shape, dtype=dtype))
    product = math_ops.matmul(
        x, y, transpose_a=transpose_a, transpose_b=transpose_b)
    return control_flow_ops.group(product)
|
|
||||||
|
|
||||||
|
|
||||||
class MatmulBenchmark(test.Benchmark):
  """Benchmark matmul!"""

  def run_graph(self, device, n, m, k, transpose_a, transpose_b, num_iters,
                dtype):
    """Times repeated runs of one matmul graph and reports the result.

    Args:
      device: String, the device to run on.
      n: tensor A's first dimension size.
      m: tensor A's second dimension size.
      k: tensor B's second dimension size.
      transpose_a: boolean value to show if tensor A is transposed.
      transpose_b: boolean value to show if tensor B is transposed.
      num_iters: number of timed iterations to run.
      dtype: numpy data type of the input tensor.

    Returns:
      The duration of the timed runs in seconds.
    """
    # Shared between the printed line and the reported benchmark name.
    input_info = (str(n) + 'x' + str(m) + 'x' + str(k) + ',ta:' +
                  str(transpose_a) + '.tb:' + str(transpose_b))

    graph = ops.Graph()
    with graph.as_default():
      output = build_graph(device, n, m, k, transpose_a, transpose_b, dtype)
      with session_lib.Session(graph=graph) as session:
        variables.global_variables_initializer().run()
        # Untimed warm-up runs before the measurement starts.
        for _ in range(500):
          session.run(output)
        start_time = time.time()
        for _ in range(num_iters):
          session.run(output)
        duration = (time.time() - start_time)
        num_items = n * m * k * 2
        throughput = num_items * num_iters / duration / 1e9
        print('%s %s input_info:%s %d %.4fsec, %.4fGitems/s.' %
              (device, str(dtype), input_info, num_iters, duration,
               throughput))

    name_template = ('matmul_{device}_{dtype}_input_info_{inputinfo}')
    benchmark_name = name_template.format(
        device=device,
        dtype=str(dtype).replace(' ', ''),
        inputinfo=input_info)

    self.report_benchmark(
        name=benchmark_name.replace(' ', ''),
        iters=num_iters,
        wall_time=duration)
    return duration

  def run_test_gpu(self, n, m, k, transpose_a, transpose_b, dtype, num_iters):
    """Runs run_graph() with the device fixed to 'gpu'."""
    self.run_graph('gpu', n, m, k, transpose_a, transpose_b, num_iters, dtype)

  def test_round(self, num_iters):
    """Runs one sweep over the benchmark shapes for float32 and float64."""
    transposes = [(False, False), (True, False), (False, True)]
    for dtype in [np.float32, np.float64]:
      # Shapes where k equals n.
      for n, m, (transpose_a, transpose_b) in itertools.product(
          [512, 1024], [1, 8, 16, 128], transposes):
        self.run_test_gpu(n, m, n, transpose_a, transpose_b, dtype, num_iters)

      for n, m, k, (transpose_a, transpose_b) in itertools.product(
          [200], [1, 8, 20], [10000], transposes):
        self.run_test_gpu(n, m, k, transpose_a, transpose_b, dtype, num_iters)

      for (n, m, k), (transpose_a, transpose_b) in itertools.product(
          [(200, 20, 20000), (1, 10000, 200)], transposes):
        self.run_test_gpu(n, m, k, transpose_a, transpose_b, dtype, num_iters)

  def benchmark_matmul(self):
    """Benchmark entry point: ten sweeps of 200 timed iterations each."""
    num_iters = 200
    for _ in range(10):
      self.test_round(num_iters)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test.main()
|
|
||||||
|
|
@ -1,122 +0,0 @@
|
||||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Tests for matmul_benchmark.py."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import itertools
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from tensorflow.core.framework import graph_pb2
|
|
||||||
from tensorflow.core.framework import node_def_pb2
|
|
||||||
from tensorflow.python.framework import ops
|
|
||||||
from tensorflow.python.ops import matmul_benchmark
|
|
||||||
from tensorflow.python.platform import test as googletest
|
|
||||||
from tensorflow.python.platform import tf_logging
|
|
||||||
|
|
||||||
|
|
||||||
def BuildGraphTest(n, m, k, transpose_a, transpose_b, dtype):
  """Returns a test method that verifies build_graph for one configuration.

  The returned method logs and returns early when no GPU is available.
  """
  config = (n, m, k, transpose_a, transpose_b)

  def Test(self):
    if googletest.is_gpu_available():
      tf_logging.info("Testing BuildGraphTest %s", config)
      self._VerifyBuildGraph(n, m, k, transpose_a, transpose_b, dtype)
    else:
      tf_logging.info("Skipping BuildGraphTest %s", config)

  return Test
|
|
||||||
|
|
||||||
|
|
||||||
def RunGraphTest(n, m, k, transpose_a, transpose_b, dtype):
  """Returns a test method that verifies run_graph for one configuration.

  The returned method logs and returns early when no GPU is available.
  """
  config = (n, m, k, transpose_a, transpose_b)

  def Test(self):
    if googletest.is_gpu_available():
      tf_logging.info("Testing RunGraphTest %s", config)
      self._VerifyRunGraph(n, m, k, transpose_a, transpose_b, dtype)
    else:
      tf_logging.info("Skipping RunGraphTest %s", config)

  return Test
|
|
||||||
|
|
||||||
|
|
||||||
class MatmulBenchmarkTest(googletest.TestCase):
  """Checks the graph built and the timing returned by matmul_benchmark."""

  def _StripNode(self, nd):
    """Copies a NodeDef keeping only name, op, input and (if set) device."""
    stripped = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
    if nd.device:
      stripped.device = nd.device
    return stripped

  def _StripGraph(self, gd):
    """Returns a GraphDef whose nodes are the stripped copies of gd's nodes."""
    stripped_nodes = []
    for nd in gd.node:
      stripped_nodes.append(self._StripNode(nd))
    return graph_pb2.GraphDef(node=stripped_nodes)

  def _VerifyBuildGraph(self, n, m, k, transpose_a, transpose_b, dtype):
    """Builds the benchmark graph on "gpu" and compares its stripped nodes."""
    graph = ops.Graph()
    with graph.as_default():
      matmul_benchmark.build_graph("gpu", n, m, k, transpose_a, transpose_b,
                                   dtype)
      gd = graph.as_graph_def()
      expected = """
node { name: "random_uniform/shape" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform/min" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform/max" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform/RandomUniform" op: "RandomUniform" input: "random_uniform/shape" device: "/device:GPU:0" }
node { name: "random_uniform/sub" op: "Sub" input: "random_uniform/max" input: "random_uniform/min" device: "/device:GPU:0" }
node { name: "random_uniform/mul" op: "Mul" input: "random_uniform/RandomUniform" input: "random_uniform/sub" device: "/device:GPU:0" }
node { name: "random_uniform" op: "Add" input: "random_uniform/mul" input: "random_uniform/min" device: "/device:GPU:0" }
node { name: "Variable" op: "VariableV2" device: "/device:GPU:0" }
node { name: "Variable/Assign" op: "Assign" input: "Variable" input: "random_uniform" device: "/device:GPU:0" }
node { name: "Variable/read" op: "Identity" input: "Variable" device: "/device:GPU:0" }
node { name: "random_uniform_1/shape" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform_1/min" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform_1/max" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform_1/RandomUniform" op: "RandomUniform" input: "random_uniform_1/shape" device: "/device:GPU:0" }
node { name: "random_uniform_1/sub" op: "Sub" input: "random_uniform_1/max" input: "random_uniform_1/min" device: "/device:GPU:0" }
node { name: "random_uniform_1/mul" op: "Mul" input: "random_uniform_1/RandomUniform" input: "random_uniform_1/sub" device: "/device:GPU:0" }
node { name: "random_uniform_1" op: "Add" input: "random_uniform_1/mul" input: "random_uniform_1/min" device: "/device:GPU:0" }
node { name: "Variable_1" op: "VariableV2" device: "/device:GPU:0" }
node { name: "Variable_1/Assign" op: "Assign" input: "Variable_1" input: "random_uniform_1" device: "/device:GPU:0" }
node { name: "Variable_1/read" op: "Identity" input: "Variable_1" device: "/device:GPU:0" }
node { name: "MatMul" op: "MatMul" input: "Variable/read" input: "Variable_1/read" device: "/device:GPU:0" }
node { name: "group_deps" op: "NoOp" input: "^MatMul" device: "/device:GPU:0" }
"""
      self.assertProtoEquals(expected, self._StripGraph(gd))

  def _VerifyRunGraph(self, n, m, k, transpose_a, transpose_b, dtype):
    """Runs one timed iteration and checks the reported duration is positive."""
    benchmark_instance = matmul_benchmark.MatmulBenchmark()
    duration = benchmark_instance.run_graph("gpu", n, m, k, transpose_a,
                                            transpose_b, 1, dtype)
    self.assertTrue(duration > 1e-6)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
  # Attach one build-graph test and one run-graph test per configuration to
  # MatmulBenchmarkTest, numbered in iteration order, then run the suite.
  dtypes = [np.float32, np.float64]
  _configs = itertools.product(
      dtypes, [512, 1024], [1, 8, 16, 128],
      [(False, False), (True, False), (False, True)])
  for index, (_dtype, _n, _m, (_transpose_a, _transpose_b)) in enumerate(
      _configs):
    _k = _n
    setattr(MatmulBenchmarkTest, "testBuildGraph_" + str(index),
            BuildGraphTest(_n, _m, _k, _transpose_a, _transpose_b, _dtype))
    setattr(MatmulBenchmarkTest, "testRunGraph_" + str(index),
            RunGraphTest(_n, _m, _k, _transpose_a, _transpose_b, _dtype))
  googletest.main()
|
|
||||||
|
|
@ -67,10 +67,6 @@ string SideString(Side s) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- AlgorithmConfig
|
|
||||||
|
|
||||||
// Returns the stored algorithm identifier formatted as a string.
string AlgorithmConfig::ToString() const {
  return port::StrCat(algorithm_);
}
|
|
||||||
|
|
||||||
string ComputationTypeString(ComputationType ty) {
|
string ComputationTypeString(ComputationType ty) {
|
||||||
switch (ty) {
|
switch (ty) {
|
||||||
case ComputationType::kF16:
|
case ComputationType::kF16:
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,7 @@ limitations under the License.
|
||||||
#include "tensorflow/stream_executor/platform/port.h"
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
|
|
||||||
#include "tensorflow/stream_executor/lib/array_slice.h"
|
#include "tensorflow/stream_executor/lib/array_slice.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
struct half;
|
struct half;
|
||||||
|
|
@ -107,10 +108,6 @@ string ComputationTypeString(ComputationType ty);
|
||||||
// Opaque identifier for an "algorithm" used by a blas routine. This functions
|
// Opaque identifier for an "algorithm" used by a blas routine. This functions
|
||||||
// as a hint to the blas library.
|
// as a hint to the blas library.
|
||||||
typedef int64 AlgorithmType;
|
typedef int64 AlgorithmType;
|
||||||
constexpr AlgorithmType kDefaultAlgorithm = -1;
|
|
||||||
constexpr AlgorithmType kDefaultBlasGemm = -2;
|
|
||||||
constexpr AlgorithmType kDefaultBlasGemv = -3;
|
|
||||||
constexpr AlgorithmType kNoAlgorithm = -4;
|
|
||||||
|
|
||||||
// blas uses -1 to represent the default algorithm. This happens to match up
|
// blas uses -1 to represent the default algorithm. This happens to match up
|
||||||
// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
|
// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
|
||||||
|
|
@ -137,28 +134,10 @@ class ProfileResult {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool is_valid_ = false;
|
bool is_valid_ = false;
|
||||||
AlgorithmType algorithm_ = kDefaultAlgorithm;
|
AlgorithmType algorithm_ = 0;
|
||||||
float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
|
float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
|
||||||
};
|
};
|
||||||
|
|
||||||
class AlgorithmConfig {
|
|
||||||
public:
|
|
||||||
AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {}
|
|
||||||
explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {}
|
|
||||||
AlgorithmType algorithm() const { return algorithm_; }
|
|
||||||
void set_algorithm(AlgorithmType val) { algorithm_ = val; }
|
|
||||||
bool operator==(const AlgorithmConfig &other) const {
|
|
||||||
return this->algorithm_ == other.algorithm_;
|
|
||||||
}
|
|
||||||
bool operator!=(const AlgorithmConfig &other) const {
|
|
||||||
return !(*this == other);
|
|
||||||
}
|
|
||||||
string ToString() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
AlgorithmType algorithm_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// BLAS support interface -- this can be derived from a GPU executor when the
|
// BLAS support interface -- this can be derived from a GPU executor when the
|
||||||
// underlying platform has an BLAS library implementation available. See
|
// underlying platform has an BLAS library implementation available. See
|
||||||
// StreamExecutor::AsBlas().
|
// StreamExecutor::AsBlas().
|
||||||
|
|
@ -474,29 +453,6 @@ class BlasSupport {
|
||||||
std::complex<double> beta,
|
std::complex<double> beta,
|
||||||
DeviceMemory<std::complex<double>> *y, int incy) = 0;
|
DeviceMemory<std::complex<double>> *y, int incy) = 0;
|
||||||
|
|
||||||
virtual bool DoBlasGemvWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
|
|
||||||
const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
|
|
||||||
int incx, float beta, DeviceMemory<float> *y, int incy,
|
|
||||||
ProfileResult *output_profile_result) = 0;
|
|
||||||
virtual bool DoBlasGemvWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
|
|
||||||
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
|
|
||||||
int incx, double beta, DeviceMemory<double> *y, int incy,
|
|
||||||
ProfileResult *output_profile_result) = 0;
|
|
||||||
virtual bool DoBlasGemvWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
|
|
||||||
std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
|
|
||||||
int lda, const DeviceMemory<std::complex<float>> &x, int incx,
|
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
|
|
||||||
ProfileResult *output_profile_result) = 0;
|
|
||||||
virtual bool DoBlasGemvWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
|
|
||||||
std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
|
|
||||||
int lda, const DeviceMemory<std::complex<double>> &x, int incx,
|
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
|
|
||||||
int incy, ProfileResult *output_profile_result) = 0;
|
|
||||||
|
|
||||||
// Performs a rank-1 update of a general matrix.
|
// Performs a rank-1 update of a general matrix.
|
||||||
//
|
//
|
||||||
// a <- alpha * x * y' + a,
|
// a <- alpha * x * y' + a,
|
||||||
|
|
@ -979,39 +935,8 @@ class BlasSupport {
|
||||||
std::complex<double> beta,
|
std::complex<double> beta,
|
||||||
DeviceMemory<std::complex<double>> *c, int ldc) = 0;
|
DeviceMemory<std::complex<double>> *c, int ldc) = 0;
|
||||||
|
|
||||||
virtual bool DoBlasGemmWithProfiling(
|
// Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. Note that
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
// any or all of these algorithms may still be
|
||||||
uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
|
|
||||||
int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
|
|
||||||
DeviceMemory<Eigen::half> *c, int ldc,
|
|
||||||
ProfileResult *output_profile_result) = 0;
|
|
||||||
virtual bool DoBlasGemmWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
|
||||||
uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
|
|
||||||
const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
|
|
||||||
int ldc, ProfileResult *output_profile_result) = 0;
|
|
||||||
virtual bool DoBlasGemmWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
|
||||||
uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
|
|
||||||
const DeviceMemory<double> &b, int ldb, double beta,
|
|
||||||
DeviceMemory<double> *c, int ldc,
|
|
||||||
ProfileResult *output_profile_result) = 0;
|
|
||||||
virtual bool DoBlasGemmWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
|
||||||
uint64 n, uint64 k, std::complex<float> alpha,
|
|
||||||
const DeviceMemory<std::complex<float>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<float>> &b, int ldb,
|
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
|
|
||||||
ProfileResult *output_profile_result) = 0;
|
|
||||||
virtual bool DoBlasGemmWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
|
||||||
uint64 n, uint64 k, std::complex<double> alpha,
|
|
||||||
const DeviceMemory<std::complex<double>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<double>> &b, int ldb,
|
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
|
|
||||||
ProfileResult *output_profile_result) = 0;
|
|
||||||
|
|
||||||
// Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
|
|
||||||
virtual bool GetBlasGemmAlgorithms(
|
virtual bool GetBlasGemmAlgorithms(
|
||||||
std::vector<AlgorithmType> *out_algorithms) = 0;
|
std::vector<AlgorithmType> *out_algorithms) = 0;
|
||||||
|
|
||||||
|
|
@ -1548,28 +1473,6 @@ class BlasSupport {
|
||||||
const DeviceMemory<std::complex<double>> &x, int incx, \
|
const DeviceMemory<std::complex<double>> &x, int incx, \
|
||||||
std::complex<double> beta, \
|
std::complex<double> beta, \
|
||||||
DeviceMemory<std::complex<double>> *y, int incy) override; \
|
DeviceMemory<std::complex<double>> *y, int incy) override; \
|
||||||
bool DoBlasGemvWithProfiling( \
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, \
|
|
||||||
const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, \
|
|
||||||
int incx, float beta, DeviceMemory<float> *y, int incy, \
|
|
||||||
blas::ProfileResult *output_profile_result) override; \
|
|
||||||
bool DoBlasGemvWithProfiling( \
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \
|
|
||||||
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, \
|
|
||||||
int incx, double beta, DeviceMemory<double> *y, int incy, \
|
|
||||||
blas::ProfileResult *output_profile_result) override; \
|
|
||||||
bool DoBlasGemvWithProfiling( \
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
|
|
||||||
std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, \
|
|
||||||
int lda, const DeviceMemory<std::complex<float>> &x, int incx, \
|
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, \
|
|
||||||
int incy, blas::ProfileResult *output_profile_result) override; \
|
|
||||||
bool DoBlasGemvWithProfiling( \
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
|
|
||||||
std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \
|
|
||||||
int lda, const DeviceMemory<std::complex<double>> &x, int incx, \
|
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *y, \
|
|
||||||
int incy, blas::ProfileResult *output_profile_result) override; \
|
|
||||||
bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \
|
bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \
|
||||||
const DeviceMemory<float> &x, int incx, \
|
const DeviceMemory<float> &x, int incx, \
|
||||||
const DeviceMemory<float> &y, int incy, \
|
const DeviceMemory<float> &y, int incy, \
|
||||||
|
|
@ -1848,39 +1751,6 @@ class BlasSupport {
|
||||||
const DeviceMemory<std::complex<double>> &b, int ldb, \
|
const DeviceMemory<std::complex<double>> &b, int ldb, \
|
||||||
std::complex<double> beta, \
|
std::complex<double> beta, \
|
||||||
DeviceMemory<std::complex<double>> *c, int ldc) override; \
|
DeviceMemory<std::complex<double>> *c, int ldc) override; \
|
||||||
bool DoBlasGemmWithProfiling( \
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
|
||||||
uint64 m, uint64 n, uint64 k, float alpha, \
|
|
||||||
const DeviceMemory<Eigen::half> &a, int lda, \
|
|
||||||
const DeviceMemory<Eigen::half> &b, int ldb, float beta, \
|
|
||||||
DeviceMemory<Eigen::half> *c, int ldc, \
|
|
||||||
blas::ProfileResult *output_profile_result) override; \
|
|
||||||
bool DoBlasGemmWithProfiling( \
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
|
||||||
uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
|
|
||||||
int lda, const DeviceMemory<float> &b, int ldb, float beta, \
|
|
||||||
DeviceMemory<float> *c, int ldc, \
|
|
||||||
blas::ProfileResult *output_profile_result) override; \
|
|
||||||
bool DoBlasGemmWithProfiling( \
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
|
||||||
uint64 m, uint64 n, uint64 k, double alpha, \
|
|
||||||
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \
|
|
||||||
int ldb, double beta, DeviceMemory<double> *c, int ldc, \
|
|
||||||
blas::ProfileResult *output_profile_result) override; \
|
|
||||||
bool DoBlasGemmWithProfiling( \
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
|
||||||
uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
|
|
||||||
const DeviceMemory<std::complex<float>> &a, int lda, \
|
|
||||||
const DeviceMemory<std::complex<float>> &b, int ldb, \
|
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
|
|
||||||
blas::ProfileResult *output_profile_result) override; \
|
|
||||||
bool DoBlasGemmWithProfiling( \
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
|
||||||
uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
|
|
||||||
const DeviceMemory<std::complex<double>> &a, int lda, \
|
|
||||||
const DeviceMemory<std::complex<double>> &b, int ldb, \
|
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \
|
|
||||||
int ldc, blas::ProfileResult *output_profile_result) override; \
|
|
||||||
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
|
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
|
||||||
override; \
|
override; \
|
||||||
bool DoBlasGemmWithAlgorithm( \
|
bool DoBlasGemmWithAlgorithm( \
|
||||||
|
|
|
||||||
|
|
@ -1857,180 +1857,6 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||||
CUDAComplex(CUDAMemoryMutable(c)), ldc);
|
CUDAComplex(CUDAMemoryMutable(c)), ldc);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CUDABlas::DoBlasGemvWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
|
|
||||||
const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
|
|
||||||
int incx, float beta, DeviceMemory<float> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
|
|
||||||
incx, beta, y, incy,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CUDABlas::DoBlasGemvWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
|
|
||||||
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
|
|
||||||
int incx, double beta, DeviceMemory<double> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
|
|
||||||
incx, beta, y, incy,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CUDABlas::DoBlasGemvWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
|
|
||||||
std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
|
|
||||||
int lda, const DeviceMemory<std::complex<float>> &x, int incx,
|
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
|
|
||||||
incx, beta, y, incy,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CUDABlas::DoBlasGemvWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
|
|
||||||
std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
|
|
||||||
int lda, const DeviceMemory<std::complex<double>> &x, int incx,
|
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
|
|
||||||
incx, beta, y, incy,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CUDABlas::DoBlasGemmWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
|
||||||
uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
|
|
||||||
int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
|
|
||||||
DeviceMemory<Eigen::half> *c, int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
|
|
||||||
lda, b, ldb, beta, c, ldc,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CUDABlas::DoBlasGemmWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
|
||||||
uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
|
|
||||||
const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
|
|
||||||
int ldc, blas::ProfileResult *output_profile_result) {
|
|
||||||
return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
|
|
||||||
lda, b, ldb, beta, c, ldc,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CUDABlas::DoBlasGemmWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
|
||||||
uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
|
|
||||||
const DeviceMemory<double> &b, int ldb, double beta,
|
|
||||||
DeviceMemory<double> *c, int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
|
|
||||||
lda, b, ldb, beta, c, ldc,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CUDABlas::DoBlasGemmWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
|
||||||
uint64 n, uint64 k, std::complex<float> alpha,
|
|
||||||
const DeviceMemory<std::complex<float>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<float>> &b, int ldb,
|
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
|
|
||||||
lda, b, ldb, beta, c, ldc,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CUDABlas::DoBlasGemmWithProfiling(
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
|
||||||
uint64 n, uint64 k, std::complex<double> alpha,
|
|
||||||
const DeviceMemory<std::complex<double>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<double>> &b, int ldb,
|
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
|
|
||||||
lda, b, ldb, beta, c, ldc,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
bool CUDABlas::DoBlasGemvWithProfilingImpl(
|
|
||||||
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
|
|
||||||
const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
|
|
||||||
const T &beta, DeviceMemory<T> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
struct TimerDeleter {
|
|
||||||
void operator()(CUDATimer *t) {
|
|
||||||
t->Destroy();
|
|
||||||
delete t;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
std::unique_ptr<CUDATimer, TimerDeleter> timer;
|
|
||||||
if (output_profile_result != nullptr) {
|
|
||||||
timer.reset(new CUDATimer(parent_));
|
|
||||||
if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call blasGemm
|
|
||||||
bool result =
|
|
||||||
DoBlasGemv(stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
|
|
||||||
|
|
||||||
if (timer != nullptr && result) {
|
|
||||||
// CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
|
|
||||||
// state.
|
|
||||||
if (!timer->Stop(AsCUDAStream(stream))) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
output_profile_result->set_is_valid(true);
|
|
||||||
output_profile_result->set_algorithm(blas::kDefaultBlasGemv);
|
|
||||||
output_profile_result->set_elapsed_time_in_ms(
|
|
||||||
timer->GetElapsedMilliseconds());
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename ParamType>
|
|
||||||
bool CUDABlas::DoBlasGemmWithProfilingImpl(
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
|
||||||
uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
|
|
||||||
int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
|
|
||||||
DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
|
|
||||||
struct TimerDeleter {
|
|
||||||
void operator()(CUDATimer *t) {
|
|
||||||
t->Destroy();
|
|
||||||
delete t;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
std::unique_ptr<CUDATimer, TimerDeleter> timer;
|
|
||||||
if (output_profile_result != nullptr) {
|
|
||||||
timer.reset(new CUDATimer(parent_));
|
|
||||||
if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call blasGemm
|
|
||||||
bool result = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a, lda, b,
|
|
||||||
ldb, beta, c, ldc);
|
|
||||||
|
|
||||||
if (timer != nullptr && result) {
|
|
||||||
// CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
|
|
||||||
// state.
|
|
||||||
if (!timer->Stop(AsCUDAStream(stream))) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
output_profile_result->set_is_valid(true);
|
|
||||||
output_profile_result->set_algorithm(blas::kDefaultBlasGemm);
|
|
||||||
output_profile_result->set_elapsed_time_in_ms(
|
|
||||||
timer->GetElapsedMilliseconds());
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename InT, typename OutT, typename CompT>
|
template <typename InT, typename OutT, typename CompT>
|
||||||
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
|
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
|
|
@ -2094,9 +1920,6 @@ bool CUDABlas::GetBlasGemmAlgorithms(
|
||||||
std::vector<blas::AlgorithmType> *out_algorithms) {
|
std::vector<blas::AlgorithmType> *out_algorithms) {
|
||||||
// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
|
// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
|
||||||
// were first introduced in CUDA 8.
|
// were first introduced in CUDA 8.
|
||||||
// Note that when CUDA version and compute capability is not sufficient, we
|
|
||||||
// still return the out_algorithms. Caller needs to make sure that in this case,
|
|
||||||
// the returned vector is empty.
|
|
||||||
#if CUDA_VERSION >= 8000
|
#if CUDA_VERSION >= 8000
|
||||||
for (cublasGemmAlgo_t algo :
|
for (cublasGemmAlgo_t algo :
|
||||||
{CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
|
{CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
|
||||||
|
|
@ -2104,10 +1927,8 @@ bool CUDABlas::GetBlasGemmAlgorithms(
|
||||||
CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7}) {
|
CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7}) {
|
||||||
out_algorithms->push_back(algo);
|
out_algorithms->push_back(algo);
|
||||||
}
|
}
|
||||||
return true;
|
|
||||||
#else
|
|
||||||
return false;
|
|
||||||
#endif
|
#endif
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CUDABlas::DoBlasGemmWithAlgorithm(
|
bool CUDABlas::DoBlasGemmWithAlgorithm(
|
||||||
|
|
|
||||||
|
|
@ -127,23 +127,6 @@ class CUDABlas : public blas::BlasSupport {
|
||||||
blas::AlgorithmType algorithm,
|
blas::AlgorithmType algorithm,
|
||||||
blas::ProfileResult *output_profile_result);
|
blas::ProfileResult *output_profile_result);
|
||||||
|
|
||||||
// Helper function for implementing DoBlasGemmWithProfiling.
|
|
||||||
template <typename T, typename ParamType>
|
|
||||||
bool DoBlasGemmWithProfilingImpl(
|
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
|
||||||
uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
|
|
||||||
int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
|
|
||||||
DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
|
|
||||||
|
|
||||||
// Helper function for implementing DoBlasGemvWithProfiling.
|
|
||||||
template <typename T>
|
|
||||||
bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
|
|
||||||
uint64 m, uint64 n, const T &alpha,
|
|
||||||
const DeviceMemory<T> &a, int lda,
|
|
||||||
const DeviceMemory<T> &x, int incx,
|
|
||||||
const T &beta, DeviceMemory<T> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result);
|
|
||||||
|
|
||||||
// mutex that guards the cuBLAS handle for this device.
|
// mutex that guards the cuBLAS handle for this device.
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3458,184 +3458,6 @@ struct ThenBlasWithProfileImpl {
|
||||||
};
|
};
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
Stream &Stream::ThenBlasGemvWithProfiling(
|
|
||||||
blas::Transpose trans, uint64 m, uint64 n, float alpha,
|
|
||||||
const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
|
|
||||||
int incx, float beta, DeviceMemory<float> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
|
|
||||||
PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
|
|
||||||
PARAM(incy));
|
|
||||||
|
|
||||||
ThenBlasWithProfileImpl<
|
|
||||||
blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
|
|
||||||
const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
|
|
||||||
impl;
|
|
||||||
return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
|
|
||||||
alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
Stream &Stream::ThenBlasGemvWithProfiling(
|
|
||||||
blas::Transpose trans, uint64 m, uint64 n, double alpha,
|
|
||||||
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
|
|
||||||
int incx, double beta, DeviceMemory<double> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
|
|
||||||
PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
|
|
||||||
PARAM(incy));
|
|
||||||
|
|
||||||
ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
|
|
||||||
const DeviceMemory<double> &, int,
|
|
||||||
const DeviceMemory<double> &, int, double,
|
|
||||||
DeviceMemory<double> *, int>
|
|
||||||
impl;
|
|
||||||
return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
|
|
||||||
alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
Stream &Stream::ThenBlasGemvWithProfiling(
|
|
||||||
blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
|
|
||||||
const DeviceMemory<std::complex<float>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<float>> &x, int incx,
|
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
|
|
||||||
PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
|
|
||||||
PARAM(incy));
|
|
||||||
|
|
||||||
ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
|
|
||||||
const DeviceMemory<std::complex<float>> &, int,
|
|
||||||
const DeviceMemory<std::complex<float>> &, int,
|
|
||||||
std::complex<float>,
|
|
||||||
DeviceMemory<std::complex<float>> *, int>
|
|
||||||
impl;
|
|
||||||
return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
|
|
||||||
alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
Stream &Stream::ThenBlasGemvWithProfiling(
|
|
||||||
blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
|
|
||||||
const DeviceMemory<std::complex<double>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<double>> &x, int incx,
|
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
|
|
||||||
PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
|
|
||||||
PARAM(incy));
|
|
||||||
|
|
||||||
ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
|
|
||||||
const DeviceMemory<std::complex<double>> &, int,
|
|
||||||
const DeviceMemory<std::complex<double>> &, int,
|
|
||||||
std::complex<double>,
|
|
||||||
DeviceMemory<std::complex<double>> *, int>
|
|
||||||
impl;
|
|
||||||
return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
|
|
||||||
alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
Stream &Stream::ThenBlasGemmWithProfiling(
|
|
||||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
|
||||||
uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
|
|
||||||
const DeviceMemory<Eigen::half> &b, int ldb, float beta,
|
|
||||||
DeviceMemory<Eigen::half> *c, int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
|
||||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
|
||||||
PARAM(beta), PARAM(c), PARAM(ldc));
|
|
||||||
|
|
||||||
ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
|
|
||||||
uint64, float, const DeviceMemory<Eigen::half> &, int,
|
|
||||||
const DeviceMemory<Eigen::half> &, int, float,
|
|
||||||
DeviceMemory<Eigen::half> *, int>
|
|
||||||
impl;
|
|
||||||
return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
|
|
||||||
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
Stream &Stream::ThenBlasGemmWithProfiling(
|
|
||||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
|
||||||
uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
|
|
||||||
const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
|
|
||||||
int ldc, blas::ProfileResult *output_profile_result) {
|
|
||||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
|
||||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
|
||||||
PARAM(beta), PARAM(c), PARAM(ldc));
|
|
||||||
|
|
||||||
ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
|
|
||||||
uint64, float, const DeviceMemory<float> &, int,
|
|
||||||
const DeviceMemory<float> &, int, float,
|
|
||||||
DeviceMemory<float> *, int>
|
|
||||||
impl;
|
|
||||||
return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
|
|
||||||
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
Stream &Stream::ThenBlasGemmWithProfiling(
|
|
||||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
|
||||||
uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
|
|
||||||
const DeviceMemory<double> &b, int ldb, double beta,
|
|
||||||
DeviceMemory<double> *c, int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
|
||||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
|
||||||
PARAM(beta), PARAM(c), PARAM(ldc));
|
|
||||||
|
|
||||||
ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
|
|
||||||
uint64, double, const DeviceMemory<double> &, int,
|
|
||||||
const DeviceMemory<double> &, int, double,
|
|
||||||
DeviceMemory<double> *, int>
|
|
||||||
impl;
|
|
||||||
return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
|
|
||||||
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
Stream &Stream::ThenBlasGemmWithProfiling(
|
|
||||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
|
||||||
uint64 k, std::complex<float> alpha,
|
|
||||||
const DeviceMemory<std::complex<float>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<float>> &b, int ldb,
|
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
|
||||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
|
||||||
PARAM(beta), PARAM(c), PARAM(ldc));
|
|
||||||
|
|
||||||
ThenBlasWithProfileImpl<
|
|
||||||
blas::Transpose, blas::Transpose, uint64, uint64, uint64,
|
|
||||||
std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
|
|
||||||
const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
|
|
||||||
DeviceMemory<std::complex<float>> *, int>
|
|
||||||
impl;
|
|
||||||
return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
|
|
||||||
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
Stream &Stream::ThenBlasGemmWithProfiling(
|
|
||||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
|
||||||
uint64 k, std::complex<double> alpha,
|
|
||||||
const DeviceMemory<std::complex<double>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<double>> &b, int ldb,
|
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result) {
|
|
||||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
|
||||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
|
||||||
PARAM(beta), PARAM(c), PARAM(ldc));
|
|
||||||
|
|
||||||
ThenBlasWithProfileImpl<
|
|
||||||
blas::Transpose, blas::Transpose, uint64, uint64, uint64,
|
|
||||||
std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
|
|
||||||
const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
|
|
||||||
DeviceMemory<std::complex<double>> *, int>
|
|
||||||
impl;
|
|
||||||
return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
|
|
||||||
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
|
||||||
output_profile_result);
|
|
||||||
}
|
|
||||||
|
|
||||||
Stream &Stream::ThenBlasGemmWithAlgorithm(
|
Stream &Stream::ThenBlasGemmWithAlgorithm(
|
||||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
||||||
uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
|
uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
|
||||||
|
|
|
||||||
|
|
@ -934,31 +934,6 @@ class Stream {
|
||||||
std::complex<double> beta,
|
std::complex<double> beta,
|
||||||
DeviceMemory<std::complex<double>> *y, int incy);
|
DeviceMemory<std::complex<double>> *y, int incy);
|
||||||
|
|
||||||
Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
|
|
||||||
float alpha, const DeviceMemory<float> &a,
|
|
||||||
int lda, const DeviceMemory<float> &x,
|
|
||||||
int incx, float beta,
|
|
||||||
DeviceMemory<float> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result);
|
|
||||||
Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
|
|
||||||
double alpha, const DeviceMemory<double> &a,
|
|
||||||
int lda, const DeviceMemory<double> &x,
|
|
||||||
int incx, double beta,
|
|
||||||
DeviceMemory<double> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result);
|
|
||||||
Stream &ThenBlasGemvWithProfiling(
|
|
||||||
blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
|
|
||||||
const DeviceMemory<std::complex<float>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<float>> &x, int incx,
|
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
|
|
||||||
blas::ProfileResult *output_profile_result);
|
|
||||||
Stream &ThenBlasGemvWithProfiling(
|
|
||||||
blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
|
|
||||||
const DeviceMemory<std::complex<double>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<double>> &x, int incx,
|
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
|
|
||||||
int incy, blas::ProfileResult *output_profile_result);
|
|
||||||
|
|
||||||
// See BlasSupport::DoBlasGer.
|
// See BlasSupport::DoBlasGer.
|
||||||
Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
|
Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
|
||||||
const DeviceMemory<float> &x, int incx,
|
const DeviceMemory<float> &x, int incx,
|
||||||
|
|
@ -1274,44 +1249,6 @@ class Stream {
|
||||||
std::complex<double> beta,
|
std::complex<double> beta,
|
||||||
DeviceMemory<std::complex<double>> *c, int ldc);
|
DeviceMemory<std::complex<double>> *c, int ldc);
|
||||||
|
|
||||||
Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
|
|
||||||
blas::Transpose transb, uint64 m, uint64 n,
|
|
||||||
uint64 k, float alpha,
|
|
||||||
const DeviceMemory<Eigen::half> &a, int lda,
|
|
||||||
const DeviceMemory<Eigen::half> &b, int ldb,
|
|
||||||
float beta, DeviceMemory<Eigen::half> *c,
|
|
||||||
int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result);
|
|
||||||
Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
|
|
||||||
blas::Transpose transb, uint64 m, uint64 n,
|
|
||||||
uint64 k, float alpha,
|
|
||||||
const DeviceMemory<float> &a, int lda,
|
|
||||||
const DeviceMemory<float> &b, int ldb,
|
|
||||||
float beta, DeviceMemory<float> *c, int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result);
|
|
||||||
Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
|
|
||||||
blas::Transpose transb, uint64 m, uint64 n,
|
|
||||||
uint64 k, double alpha,
|
|
||||||
const DeviceMemory<double> &a, int lda,
|
|
||||||
const DeviceMemory<double> &b, int ldb,
|
|
||||||
double beta, DeviceMemory<double> *c,
|
|
||||||
int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result);
|
|
||||||
Stream &ThenBlasGemmWithProfiling(
|
|
||||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
|
||||||
uint64 k, std::complex<float> alpha,
|
|
||||||
const DeviceMemory<std::complex<float>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<float>> &b, int ldb,
|
|
||||||
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result);
|
|
||||||
Stream &ThenBlasGemmWithProfiling(
|
|
||||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
|
||||||
uint64 k, std::complex<double> alpha,
|
|
||||||
const DeviceMemory<std::complex<double>> &a, int lda,
|
|
||||||
const DeviceMemory<std::complex<double>> &b, int ldb,
|
|
||||||
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
|
|
||||||
blas::ProfileResult *output_profile_result);
|
|
||||||
|
|
||||||
// See BlasSupport::DoBlasGemmWithAlgorithm.
|
// See BlasSupport::DoBlasGemmWithAlgorithm.
|
||||||
Stream &ThenBlasGemmWithAlgorithm(
|
Stream &ThenBlasGemmWithAlgorithm(
|
||||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user