add Cuda{2D,3D}LaunchConfig that maximizes occupancy (#10032)

* add Cuda{2D,3D}LaunchConfig that max occupancy
* remove default val, check input<=0
* add max size check
* fix typo
* tests, docs, and related changes
* build the test
* buildify
* cudaOccupancy... call check success, and style fix

This commit is contained in:
parent 187d233374
commit b440abce7f

tensorflow/core/BUILD
@@ -82,6 +82,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu")
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
+load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test")

# For platform specific build config
load(

@@ -2323,6 +2324,18 @@ tf_cc_test_gpu(
    ],
)

+tf_cuda_only_cc_test(
+    name = "util_cuda_kernel_helper_test",
+    srcs = [
+        "util/cuda_kernel_helper_test.cu.cc",
+    ],
+    deps = [
+        ":test",
+        ":test_main",
+        "//third_party/eigen3",
+    ],
+)
+
tf_cc_test_gpu(
    name = "memory_types_test",
    size = "small",

tensorflow/core/util/cuda_kernel_helper.h
@@ -20,13 +20,95 @@ limitations under the License.

#include <algorithm>

+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/platform/types.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
// GetCuda3DLaunchConfig:
//
// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig: one
// version uses heuristics without any knowledge of the device kernel, the
// other version uses cudaOccupancyMaxPotentialBlockSize to determine the
// theoretical launch parameters that maximize occupancy. Currently, only the
// maximum-occupancy version of GetCuda3DLaunchConfig is available.
//
// For a large number of work elements, the convention is that each kernel
// iterates through its assigned range. The return value of GetCudaLaunchConfig
// is struct CudaLaunchConfig, which contains all the information needed for
// the kernel launch: the virtual number of threads, plus the number of threads
// per block and the number of blocks used inside <<< >>> of the kernel launch.
// GetCuda2DLaunchConfig and GetCuda3DLaunchConfig do the same thing as
// GetCudaLaunchConfig; the only difference is the dimensionality. The macros
// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP can be used to write the inner
// loop.
//
/* Sample code:

__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
  CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
    do_your_job_here;
  }
}

__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
  CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      do_your_job_here;
    }
  }
}

__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
  CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
        do_your_job_here;
      }
    }
  }
}

void MyDriverFunc(const GPUDevice &d) {
  // use heuristics
  CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
  MyKernel1D <<<cfg1.block_count,
                cfg1.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
  Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
  MyKernel2D <<<cfg2.block_count,
                cfg2.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
  Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
  MyKernel3D <<<cfg3.block_count,
                cfg3.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);

  // maximize occupancy
  CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0);
  MyKernel1D <<<cfg4.block_count,
                cfg4.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
  Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
                                                  MyKernel2D, 0, 0);
  MyKernel2D <<<cfg5.block_count,
                cfg5.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
  Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
                                                  MyKernel3D, 0, 0);
  MyKernel3D <<<cfg6.block_count,
                cfg6.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
}

// See the test for more examples:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc

*/

#define CUDA_1D_KERNEL_LOOP(i, n)                            \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
       i += blockDim.x * gridDim.x)

#define CUDA_AXIS_KERNEL_LOOP(i, n, axis)                                  \
  for (int i = blockIdx.axis * blockDim.axis + threadIdx.axis; i < n.axis; \
       i += blockDim.axis * gridDim.axis)
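
// For reference, instantiating the macro above as
// CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) yields a
// grid-stride loop over the y axis:
//
//   for (int y = blockIdx.y * blockDim.y + threadIdx.y;
//        y < config.virtual_thread_count.y;
//        y += blockDim.y * gridDim.y)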

#define DIV_UP(a, b) (((a) + (b) - 1) / (b))

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

@@ -47,16 +129,22 @@ struct CudaLaunchConfig {
// memory-limited.
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
                                            const GPUDevice& d) {
+  CudaLaunchConfig config;
+
+  // in case of invalid input, return the default config, whose values are all -1
+  if (work_element_count <= 0) {
+    return config;
+  }
+
  const int virtual_thread_count = work_element_count;
  const int physical_thread_count = std::min(
      d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
      virtual_thread_count);
  const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
-  const int block_count = std::min(
-      (physical_thread_count + thread_per_block - 1) / thread_per_block,
+  const int block_count =
+      std::min(DIV_UP(physical_thread_count, thread_per_block),
               d.getNumCudaMultiProcessors());

-  CudaLaunchConfig config;
  config.virtual_thread_count = virtual_thread_count;
  config.thread_per_block = thread_per_block;
  config.block_count = block_count;
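
// Worked example (hypothetical device: 56 SMs, 2048 resident threads per SM,
// 1024 max threads per block) for GetCudaLaunchConfig(10240, d):
//   virtual_thread_count  = 10240
//   physical_thread_count = min(56 * 2048, 10240) = 10240
//   thread_per_block      = min(1024, 1024) = 1024
//   block_count           = min(DIV_UP(10240, 1024), 56) = 10
// i.e. a <<<10, 1024>>> launch in which each thread handles exactly one element.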

@@ -70,16 +158,23 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
                                            const GPUDevice& d, DeviceFunc func,
                                            size_t dynamic_shared_memory_size,
                                            int block_size_limit) {
+  CudaLaunchConfig config;
+
+  if (work_element_count <= 0) {
+    return config;
+  }
+
  int block_count = 0;
  int thread_per_block = 0;
-  cudaOccupancyMaxPotentialBlockSize(&block_count, &thread_per_block, func,
-                                     dynamic_shared_memory_size,
-                                     block_size_limit);
-  block_count =
-      std::min(block_count,
-               (work_element_count + thread_per_block - 1) / thread_per_block);
-
-  CudaLaunchConfig config;
+  cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
+      &block_count, &thread_per_block, func, dynamic_shared_memory_size,
+      block_size_limit);
+  CHECK_EQ(err, cudaSuccess);
+
+  block_count =
+      std::min(block_count, DIV_UP(work_element_count, thread_per_block));
+
  config.virtual_thread_count = work_element_count;
  config.thread_per_block = thread_per_block;
  config.block_count = block_count;

@@ -87,16 +182,18 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
}

struct Cuda2DLaunchConfig {
-  dim3 virtual_thread_count;
-  dim3 thread_per_block;
-  dim3 block_count;
+  dim3 virtual_thread_count = dim3(0, 0, 0);
+  dim3 thread_per_block = dim3(0, 0, 0);
+  dim3 block_count = dim3(0, 0, 0);
};

inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
                                                const GPUDevice& d) {
  Cuda2DLaunchConfig config;

-  config.virtual_thread_count = dim3(xdim, ydim, 1);
+  if (xdim <= 0 || ydim <= 0) {
+    return config;
+  }

  const int kThreadsPerBlock = 256;
  int block_cols = std::min(xdim, kThreadsPerBlock);

@@ -108,16 +205,78 @@ inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,

  const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);

+  config.virtual_thread_count = dim3(xdim, ydim, 1);
  config.thread_per_block = dim3(block_cols, block_rows, 1);

-  int grid_x = std::min((xdim + block_cols - 1) / block_cols, max_blocks);
+  int grid_x = std::min(DIV_UP(xdim, block_cols), max_blocks);

  config.block_count = dim3(
      grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);

  return config;
}
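
// Worked example for GetCuda2DLaunchConfig(2048, 1024, d) on a hypothetical
// device with 56 SMs and 2048 resident threads per SM (so
// physical_thread_count = 114688), assuming block_rows works out to 1
// (its computation falls outside this hunk):
//   block_cols = min(2048, 256) = 256, so thread_per_block = (256, 1, 1)
//   max_blocks = max(114688 / 256, 1) = 448
//   grid_x     = min(DIV_UP(2048, 256), 448) = 8
//   grid_y     = min(448 / 8, max(1024 / 1, 1)) = 56
// giving block_count = (8, 56, 1); the remaining y rows are covered by the
// CUDA_AXIS_KERNEL_LOOP grid-stride loop.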

// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
// This variant takes the resource limits of func into account to maximize
// occupancy.
using Cuda3DLaunchConfig = Cuda2DLaunchConfig;

template <typename DeviceFunc>
inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
    int xdim, int ydim, int zdim, const GPUDevice& d, DeviceFunc func,
    size_t dynamic_shared_memory_size, int block_size_limit) {
  Cuda3DLaunchConfig config;

  if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
    return config;
  }

  int dev;
  cudaGetDevice(&dev);
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, dev);
  int xthreadlimit = deviceProp.maxThreadsDim[0];
  int ythreadlimit = deviceProp.maxThreadsDim[1];
  int zthreadlimit = deviceProp.maxThreadsDim[2];
  int xgridlimit = deviceProp.maxGridSize[0];
  int ygridlimit = deviceProp.maxGridSize[1];
  int zgridlimit = deviceProp.maxGridSize[2];

  int block_count = 0;
  int thread_per_block = 0;
  cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
      &block_count, &thread_per_block, func, dynamic_shared_memory_size,
      block_size_limit);
  CHECK_EQ(err, cudaSuccess);

#define MIN3(a, b, c) std::min((a), std::min((b), (c)))
  int threadsx = MIN3(xdim, thread_per_block, xthreadlimit);
  int threadsy =
      MIN3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
  int threadsz =
      MIN3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
           zthreadlimit);

  int blocksx = MIN3(block_count, DIV_UP(xdim, threadsx), xgridlimit);
  int blocksy =
      MIN3(DIV_UP(block_count, blocksx), DIV_UP(ydim, threadsy), ygridlimit);
  int blocksz = MIN3(DIV_UP(block_count, (blocksx * blocksy)),
                     DIV_UP(zdim, threadsz), zgridlimit);
#undef MIN3

  config.virtual_thread_count = dim3(xdim, ydim, zdim);
  config.thread_per_block = dim3(threadsx, threadsy, threadsz);
  config.block_count = dim3(blocksx, blocksy, blocksz);
  return config;
}
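
// Worked example for GetCuda3DLaunchConfig(129, 64, 1024, d, kernel, 0, 0),
// assuming cudaOccupancyMaxPotentialBlockSize reports block_count = 20 and
// thread_per_block = 1024, with device limits of (1024, 1024, 64) threads and
// (2^31-1, 65535, 65535) blocks per axis:
//   threadsx = MIN3(129, 1024, 1024)                        = 129
//   threadsy = MIN3(64, max(1024 / 129, 1), 1024)           = 7
//   threadsz = MIN3(1024, max(1024 / (129 * 7), 1), 64)     = 1
//   blocksx  = MIN3(20, DIV_UP(129, 129), 2^31-1)           = 1
//   blocksy  = MIN3(DIV_UP(20, 1), DIV_UP(64, 7), 65535)    = 10
//   blocksz  = MIN3(DIV_UP(20, 10), DIV_UP(1024, 1), 65535) = 2
// i.e. 129x7x1 = 903-thread blocks in a (1, 10, 2) grid; the z axis relies on
// the CUDA_AXIS_KERNEL_LOOP grid-stride loop to cover all 1024 planes.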

template <typename DeviceFunc>
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
    int xdim, int ydim, const GPUDevice& d, DeviceFunc func,
    size_t dynamic_shared_memory_size, int block_size_limit) {
  return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
                               dynamic_shared_memory_size, block_size_limit);
}

namespace gpu {

template <typename IntType>

@@ -511,6 +670,8 @@ __device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(double value, int laneMask,

}  // namespace tensorflow

#undef DIV_UP

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
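
For reference, a minimal standalone sketch of the cudaOccupancyMaxPotentialBlockSize
pattern that the occupancy-based helpers above wrap; FillZero and LaunchFillZero are
illustrative names, not part of this commit:

#include <algorithm>
#include <cuda_runtime.h>

__global__ void FillZero(int n, float* out) {
  // Grid-stride loop, so any grid/block size covers all n elements.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = 0.0f;
  }
}

void LaunchFillZero(int n, float* out, cudaStream_t stream) {
  int min_grid_size = 0;  // grid size needed to saturate the device
  int block_size = 0;     // block size that maximizes occupancy
  cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
      &min_grid_size, &block_size, FillZero, /*dynamicSMemSize=*/0,
      /*blockSizeLimit=*/0);
  if (err != cudaSuccess) return;
  // Never launch more blocks than there are elements to process.
  int grid_size = std::min(min_grid_size, (n + block_size - 1) / block_size);
  FillZero<<<grid_size, block_size, 0, stream>>>(n, out);
}
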

tensorflow/core/util/cuda_kernel_helper_test.cu.cc (new file, 303 lines)
@@ -0,0 +1,303 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#define EIGEN_USE_GPU

#include <numeric>
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

#define CUDA_EXPECT_SUCCESS                                 \
  {                                                         \
    cudaDeviceSynchronize();                                \
    cudaError_t err = cudaGetLastError();                   \
    EXPECT_EQ(cudaSuccess, err) << cudaGetErrorString(err); \
  }

#define CUDA_ASSERT_SUCCESS                                 \
  {                                                         \
    cudaDeviceSynchronize();                                \
    cudaError_t err = cudaGetLastError();                   \
    ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err); \
  }

namespace tensorflow {

namespace {

__global__ void SetOutbufZero(CudaLaunchConfig config, int* outbuf) {
  CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) { outbuf[x] = 0; }
}

// Count the number of jobs by doing an atomic +1 on each visited element.
__global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) {
  CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
    if (x < 0) {  // x might overflow when testing extreme cases
      break;
    }
    atomicAdd(&outbuf[x % bufsize], 1);
  }
}
__global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
  CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    if (x < 0) {  // x might overflow when testing extreme cases
      break;
    }
    CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      if (y < 0) {  // y might overflow when testing extreme cases
        break;
      }
      int idx = x * config.virtual_thread_count.y + y;
      atomicAdd(&outbuf[idx % bufsize], 1);
    }
  }
}
__global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) {
  CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    if (x < 0) {  // x might overflow when testing extreme cases
      break;
    }
    CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      if (y < 0) {  // y might overflow when testing extreme cases
        break;
      }
      CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
        if (z < 0) {  // z might overflow when testing extreme cases
          break;
        }
        int idx =
            x * config.virtual_thread_count.y * config.virtual_thread_count.z +
            y * config.virtual_thread_count.z + z;
        atomicAdd(&outbuf[idx % bufsize], 1);
      }
    }
  }
}

}  // namespace

class CudaLaunchConfigTest : public ::testing::Test {
 protected:
  const int bufsize = 1024;
  int* outbuf = nullptr;
  Eigen::CudaStreamDevice stream;
  GPUDevice d = GPUDevice(&stream);

  virtual void SetUp() {
    cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize);
    ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err);
  }

  virtual void TearDown() {
    cudaDeviceSynchronize();
    cudaFree(outbuf);
    outbuf = nullptr;
  }
};

TEST_F(CudaLaunchConfigTest, GetCudaLaunchConfig) {
  CudaLaunchConfig cfg;

  // test invalid inputs
  CudaLaunchConfig default_value;
  cfg = GetCudaLaunchConfig(0, d);
  EXPECT_EQ(default_value.virtual_thread_count, cfg.virtual_thread_count);
  EXPECT_EQ(default_value.block_count, cfg.block_count);
  EXPECT_EQ(default_value.thread_per_block, cfg.thread_per_block);

  cfg = GetCudaLaunchConfig(-1, d);
  EXPECT_EQ(default_value.virtual_thread_count, cfg.virtual_thread_count);
  EXPECT_EQ(default_value.block_count, cfg.block_count);
  EXPECT_EQ(default_value.thread_per_block, cfg.thread_per_block);

  cfg = GetCudaLaunchConfig(0, d, Count1D, 0, 0);
  EXPECT_EQ(default_value.virtual_thread_count, cfg.virtual_thread_count);
  EXPECT_EQ(default_value.block_count, cfg.block_count);
  EXPECT_EQ(default_value.thread_per_block, cfg.thread_per_block);

  cfg = GetCudaLaunchConfig(-1, d, Count1D, 0, 0);
  EXPECT_EQ(default_value.virtual_thread_count, cfg.virtual_thread_count);
  EXPECT_EQ(default_value.block_count, cfg.block_count);
  EXPECT_EQ(default_value.thread_per_block, cfg.thread_per_block);

  // test valid inputs
#define TEST_LAUNCH_PARAMETER(work_element_count)                              \
  cfg = GetCudaLaunchConfig(bufsize, d);                                       \
  SetOutbufZero<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>      \
      (cfg, outbuf);                                                           \
  CUDA_ASSERT_SUCCESS                                                          \
  cfg = GetCudaLaunchConfig(work_element_count, d);                            \
  Count1D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> (          \
      cfg, bufsize, outbuf);                                                   \
  CUDA_EXPECT_SUCCESS                                                          \
  EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0)); \
                                                                               \
  cfg = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0);                  \
  SetOutbufZero<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>      \
      (cfg, outbuf);                                                           \
  CUDA_ASSERT_SUCCESS                                                          \
  cfg = GetCudaLaunchConfig(work_element_count, d, Count1D, 0, 0);             \
  Count1D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> (          \
      cfg, bufsize, outbuf);                                                   \
  CUDA_EXPECT_SUCCESS                                                          \
  EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0))

  TEST_LAUNCH_PARAMETER(128);
  TEST_LAUNCH_PARAMETER(129);
  TEST_LAUNCH_PARAMETER(511);
  TEST_LAUNCH_PARAMETER(512);
  TEST_LAUNCH_PARAMETER(2048);
  TEST_LAUNCH_PARAMETER(2049);
  TEST_LAUNCH_PARAMETER(8191);
  TEST_LAUNCH_PARAMETER(8192);
  TEST_LAUNCH_PARAMETER(123456);
  TEST_LAUNCH_PARAMETER(1 << 31 - 1);  // evaluates to 1 << 30 due to operator precedence
#undef TEST_LAUNCH_PARAMETER
}

bool operator==(const Cuda2DLaunchConfig& a, const Cuda2DLaunchConfig& b) {
  return a.thread_per_block.x == b.thread_per_block.x &&
         a.thread_per_block.y == b.thread_per_block.y &&
         a.thread_per_block.z == b.thread_per_block.z &&
         a.block_count.x == b.block_count.x &&
         a.block_count.y == b.block_count.y &&
         a.block_count.z == b.block_count.z &&
         a.thread_per_block.x == b.thread_per_block.x &&
         a.thread_per_block.y == b.thread_per_block.y &&
         a.thread_per_block.z == b.thread_per_block.z;
}

TEST_F(CudaLaunchConfigTest, GetCuda2DLaunchConfig) {
  Cuda2DLaunchConfig cfg;
  CudaLaunchConfig cfg1d;

  // test invalid inputs
  Cuda2DLaunchConfig default_value;
  cfg = GetCuda2DLaunchConfig(1, 0, d);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda2DLaunchConfig(1, -1, d);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda2DLaunchConfig(-1, 1, d);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda2DLaunchConfig(-1, 1, d);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda2DLaunchConfig(0, -1, d);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda2DLaunchConfig(0, 0, d);
  EXPECT_EQ(default_value, cfg);

  cfg = GetCuda2DLaunchConfig(1, 0, d, Count2D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda2DLaunchConfig(1, -1, d, Count2D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda2DLaunchConfig(-1, 1, d, Count2D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda2DLaunchConfig(-1, 1, d, Count2D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda2DLaunchConfig(0, -1, d, Count2D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda2DLaunchConfig(0, 0, d, Count2D, 0, 0);
  EXPECT_EQ(default_value, cfg);

  // test valid inputs
#define TEST_LAUNCH_PARAMETER(dimx, dimy)                                      \
  cfg1d = GetCudaLaunchConfig(bufsize, d);                                     \
  SetOutbufZero<<<cfg1d.block_count, cfg1d.thread_per_block, 0, d.stream()>>>  \
      (cfg1d, outbuf);                                                         \
  CUDA_ASSERT_SUCCESS                                                          \
  cfg = GetCuda2DLaunchConfig(dimx, dimy, d);                                  \
  Count2D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> (          \
      cfg, bufsize, outbuf);                                                   \
  CUDA_EXPECT_SUCCESS                                                          \
  EXPECT_EQ(dimx * dimy, std::accumulate(outbuf, outbuf + bufsize, 0));        \
                                                                               \
  cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0);                \
  SetOutbufZero<<<cfg1d.block_count, cfg1d.thread_per_block, 0, d.stream()>>>  \
      (cfg1d, outbuf);                                                         \
  CUDA_ASSERT_SUCCESS                                                          \
  cfg = GetCuda2DLaunchConfig(dimx, dimy, d, Count2D, 0, 0);                   \
  Count2D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> (          \
      cfg, bufsize, outbuf);                                                   \
  CUDA_EXPECT_SUCCESS                                                          \
  EXPECT_EQ(dimx * dimy, std::accumulate(outbuf, outbuf + bufsize, 0))

  TEST_LAUNCH_PARAMETER(128, 128);
  TEST_LAUNCH_PARAMETER(129, 64);
  TEST_LAUNCH_PARAMETER(511, 2048);
  TEST_LAUNCH_PARAMETER(512, 512);
  TEST_LAUNCH_PARAMETER(2048, 1024);
  TEST_LAUNCH_PARAMETER(2049, 32);
  TEST_LAUNCH_PARAMETER(8191, 1);
  TEST_LAUNCH_PARAMETER(8192, 10);
  TEST_LAUNCH_PARAMETER(123456, 12);
  TEST_LAUNCH_PARAMETER(1, (1 << 31 - 1));
  TEST_LAUNCH_PARAMETER((1 << 31 - 1), 1);
#undef TEST_LAUNCH_PARAMETER
}

TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) {
  Cuda3DLaunchConfig cfg;
  CudaLaunchConfig cfg1d;

  // test invalid inputs
  Cuda3DLaunchConfig default_value;
  cfg = GetCuda3DLaunchConfig(0, 1, 1, d, Count3D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda3DLaunchConfig(-1, 1, 1, d, Count3D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda3DLaunchConfig(1, 0, 1, d, Count3D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda3DLaunchConfig(1, -1, 1, d, Count3D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda3DLaunchConfig(1, 1, 0, d, Count3D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda3DLaunchConfig(1, 1, -1, d, Count3D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda3DLaunchConfig(0, 0, 0, d, Count3D, 0, 0);
  EXPECT_EQ(default_value, cfg);
  cfg = GetCuda3DLaunchConfig(-1, -1, -1, d, Count3D, 0, 0);
  EXPECT_EQ(default_value, cfg);

  // test valid inputs
#define TEST_LAUNCH_PARAMETER(dimx, dimy, dimz)                                \
  cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0);                \
  SetOutbufZero<<<cfg1d.block_count, cfg1d.thread_per_block, 0, d.stream()>>>  \
      (cfg1d, outbuf);                                                         \
  CUDA_ASSERT_SUCCESS                                                          \
  cfg = GetCuda3DLaunchConfig(dimx, dimy, dimz, d, Count3D, 0, 0);             \
  Count3D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> (          \
      cfg, bufsize, outbuf);                                                   \
  CUDA_EXPECT_SUCCESS                                                          \
  EXPECT_EQ(dimx * dimy * dimz, std::accumulate(outbuf, outbuf + bufsize, 0))

  TEST_LAUNCH_PARAMETER(128, 128, 128);
  TEST_LAUNCH_PARAMETER(129, 64, 1024);
  TEST_LAUNCH_PARAMETER(511, 2048, 128);
  TEST_LAUNCH_PARAMETER(512, 512, 64);
  TEST_LAUNCH_PARAMETER(2048, 1024, 128);
  TEST_LAUNCH_PARAMETER(2049, 32, 1024);
  TEST_LAUNCH_PARAMETER(8191, 1, 1024);
  TEST_LAUNCH_PARAMETER(8192, 10, 32);
  TEST_LAUNCH_PARAMETER(123456, 12, 21);
  TEST_LAUNCH_PARAMETER(1, 1, (1 << 31 - 1));
  TEST_LAUNCH_PARAMETER(1, (1 << 31 - 1), 1);
  TEST_LAUNCH_PARAMETER((1 << 31 - 1), 1, 1);
#undef TEST_LAUNCH_PARAMETER
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA

tensorflow/tensorflow.bzl
@@ -461,6 +461,29 @@ def tf_cuda_cc_test(name,
      linkopts=linkopts,
      args=args)

+def tf_cuda_only_cc_test(name,
+                         srcs=[],
+                         deps=[],
+                         tags=[],
+                         data=[],
+                         size="medium",
+                         linkstatic=0,
+                         args=[],
+                         linkopts=[]):
+  native.cc_test(
+      name="%s%s" % (name, "_gpu"),
+      srcs=srcs,
+      size=size,
+      args=args,
+      copts=_cuda_copts() + tf_copts(),
+      data=data,
+      deps=deps + if_cuda([
+          clean_dep("//tensorflow/core:cuda"),
+          clean_dep("//tensorflow/core:gpu_lib"),
+      ]),
+      linkopts=["-lpthread", "-lm"] + linkopts,
+      linkstatic=linkstatic,
+      tags=tags)
+
# Create a cc_test for each of the tensorflow tests listed in "tests"
def tf_cc_tests(srcs,