add Cuda{2D,3D}LaunchConfig that maximizes occupancy (#10032)

* add Cuda{2D,3D}LaunchConfig that max occupancy

* remove default val, check input<=0

* add max size check

* fix typo

* tests, docs, and related changes

* build the test

* buildify

* cudaOccupancy... call check success, and style fix
This commit is contained in:
Gao, Xiang 2017-06-06 15:33:15 -04:00 committed by Rasmus Munk Larsen
parent 187d233374
commit b440abce7f
4 changed files with 518 additions and 18 deletions

View File

@ -82,6 +82,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu")
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test")
# For platform specific build config
load(
@ -2323,6 +2324,18 @@ tf_cc_test_gpu(
],
)
tf_cuda_only_cc_test(
name = "util_cuda_kernel_helper_test",
srcs = [
"util/cuda_kernel_helper_test.cu.cc",
],
deps = [
":test",
":test_main",
"//third_party/eigen3",
],
)
tf_cc_test_gpu(
name = "memory_types_test",
size = "small",

View File

@ -20,13 +20,95 @@ limitations under the License.
#include <algorithm>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/platform/types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
// GetCuda3DLaunchConfig:
//
// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one
// version uses heuristics without any knowledge of the device kernel, the other
// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
// launch parameters that maximize occupancy. Currently, only the maximum
// occupancy version of GetCuda3DLaunchConfig is available.
//
// For large number of work elements, the convention is that each kernel would
// iterate through its assigned range. The return value of GetCudaLaunchConfig
// is struct CudaLaunchConfig, which contains all the information needed for the
// kernel launch, including: virtual number of threads, the number of threads
// per block and number of threads per block used inside <<< >>> of a kernel
// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing
// as CudaLaunchConfig. The only difference is the dimension. The macros
// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop.
//
/* Sample code:
__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
do_your_job_here;
}
}
__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
do_your_job_here;
}
}
}
__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
do_your_job_here;
}
}
}
}
void MyDriverFunc(const GPUDevice &d) {
// use heuristics
CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
MyKernel1D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
MyKernel2D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
MyKernel3D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
// maximize occupancy
CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 );
MyKernel1D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
MyKernel1D, 0, 0);
MyKernel2D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
MyKernel1D, 0, 0);
MyKernel3D <<<config.block_count,
config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
}
// See the test for this for more example:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
*/
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
for (int i = blockIdx.axis * blockDim.axis + threadIdx.axis; i < n.axis; \
i += blockDim.axis * gridDim.axis)
#define DIV_UP(a, b) (((a) + (b) - 1) / (b))
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
@ -47,16 +129,22 @@ struct CudaLaunchConfig {
// memory-limited.
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
const GPUDevice& d) {
CudaLaunchConfig config;
// in case of invalid input, return the default value config, which has all -1
if (work_element_count <= 0) {
return config;
}
const int virtual_thread_count = work_element_count;
const int physical_thread_count = std::min(
d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
virtual_thread_count);
const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
const int block_count = std::min(
(physical_thread_count + thread_per_block - 1) / thread_per_block,
const int block_count =
std::min(DIV_UP(physical_thread_count, thread_per_block),
d.getNumCudaMultiProcessors());
CudaLaunchConfig config;
config.virtual_thread_count = virtual_thread_count;
config.thread_per_block = thread_per_block;
config.block_count = block_count;
@ -70,16 +158,23 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
const GPUDevice& d, DeviceFunc func,
size_t dynamic_shared_memory_size,
int block_size_limit) {
CudaLaunchConfig config;
if (work_element_count <= 0) {
return config;
}
int block_count = 0;
int thread_per_block = 0;
cudaOccupancyMaxPotentialBlockSize(&block_count, &thread_per_block, func,
dynamic_shared_memory_size,
block_size_limit);
block_count =
std::min(block_count,
(work_element_count + thread_per_block - 1) / thread_per_block);
CudaLaunchConfig config;
cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
&block_count, &thread_per_block, func, dynamic_shared_memory_size,
block_size_limit);
CHECK_EQ(err, cudaSuccess);
block_count =
std::min(block_count, DIV_UP(work_element_count, thread_per_block));
config.virtual_thread_count = work_element_count;
config.thread_per_block = thread_per_block;
config.block_count = block_count;
@ -87,16 +182,18 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
}
struct Cuda2DLaunchConfig {
dim3 virtual_thread_count;
dim3 thread_per_block;
dim3 block_count;
dim3 virtual_thread_count = dim3(0, 0, 0);
dim3 thread_per_block = dim3(0, 0, 0);
dim3 block_count = dim3(0, 0, 0);
};
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
const GPUDevice& d) {
Cuda2DLaunchConfig config;
config.virtual_thread_count = dim3(xdim, ydim, 1);
if (xdim <= 0 || ydim <= 0) {
return config;
}
const int kThreadsPerBlock = 256;
int block_cols = std::min(xdim, kThreadsPerBlock);
@ -108,16 +205,78 @@ inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
config.virtual_thread_count = dim3(xdim, ydim, 1);
config.thread_per_block = dim3(block_cols, block_rows, 1);
int grid_x = std::min((xdim + block_cols - 1) / block_cols, max_blocks);
int grid_x = std::min(DIV_UP(xdim, block_cols), max_blocks);
config.block_count = dim3(
grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
return config;
}
// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
// This variant takes the resource limits of func into account to maximize
// occupancy.
using Cuda3DLaunchConfig = Cuda2DLaunchConfig;
template <typename DeviceFunc>
inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
int xdim, int ydim, int zdim, const GPUDevice& d, DeviceFunc func,
size_t dynamic_shared_memory_size, int block_size_limit) {
Cuda3DLaunchConfig config;
if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
return config;
}
int dev;
cudaGetDevice(&dev);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, dev);
int xthreadlimit = deviceProp.maxThreadsDim[0];
int ythreadlimit = deviceProp.maxThreadsDim[1];
int zthreadlimit = deviceProp.maxThreadsDim[2];
int xgridlimit = deviceProp.maxGridSize[0];
int ygridlimit = deviceProp.maxGridSize[1];
int zgridlimit = deviceProp.maxGridSize[2];
int block_count = 0;
int thread_per_block = 0;
cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
&block_count, &thread_per_block, func, dynamic_shared_memory_size,
block_size_limit);
CHECK_EQ(err, cudaSuccess);
#define MIN3(a, b, c) std::min((a), std::min((b), (c)))
int threadsx = MIN3(xdim, thread_per_block, xthreadlimit);
int threadsy =
MIN3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
int threadsz =
MIN3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
zthreadlimit);
int blocksx = MIN3(block_count, DIV_UP(xdim, threadsx), xgridlimit);
int blocksy =
MIN3(DIV_UP(block_count, blocksx), DIV_UP(ydim, threadsy), ygridlimit);
int blocksz = MIN3(DIV_UP(block_count, (blocksx * blocksy)),
DIV_UP(zdim, threadsz), zgridlimit);
#undef MIN3
config.virtual_thread_count = dim3(xdim, ydim, zdim);
config.thread_per_block = dim3(threadsx, threadsy, threadsz);
config.block_count = dim3(blocksx, blocksy, blocksz);
return config;
}
template <typename DeviceFunc>
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
int xdim, int ydim, const GPUDevice& d, DeviceFunc func,
size_t dynamic_shared_memory_size, int block_size_limit) {
return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
dynamic_shared_memory_size, block_size_limit);
}
namespace gpu {
template <typename IntType>
@ -511,6 +670,8 @@ __device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(double value, int laneMask,
} // namespace tensorflow
#undef DIV_UP
#endif // GOOGLE_CUDA
#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_

View File

@ -0,0 +1,303 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include <numeric>
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#define CUDA_EXPECT_SUCCESS \
{ \
cudaDeviceSynchronize(); \
cudaError_t err = cudaGetLastError(); \
EXPECT_EQ(cudaSuccess, err) << cudaGetErrorString(err); \
}
#define CUDA_ASSERT_SUCCESS \
{ \
cudaDeviceSynchronize(); \
cudaError_t err = cudaGetLastError(); \
ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err); \
}
namespace tensorflow {
namespace {
__global__ void SetOutbufZero(CudaLaunchConfig config, int* outbuf) {
CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) { outbuf[x] = 0; }
}
// counting number of jobs by using atomic +1
__global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) {
CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
if (x < 0) { // x might overflow when testing extreme case
break;
}
atomicAdd(&outbuf[x % bufsize], 1);
}
}
__global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
if (x < 0) { // x might overflow when testing extreme case
break;
}
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
if (y < 0) { // y might overflow when testing extreme case
break;
}
int idx = x * config.virtual_thread_count.y + y;
atomicAdd(&outbuf[idx % bufsize], 1);
}
}
}
__global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) {
CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
if (x < 0) { // x might overflow when testing extreme case
break;
}
CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
if (y < 0) { // y might overflow when testing extreme case
break;
}
CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
if (z < 0) { // z might overflow when testing extreme case
break;
}
int idx =
x * config.virtual_thread_count.y * config.virtual_thread_count.z +
y * config.virtual_thread_count.z + z;
atomicAdd(&outbuf[idx % bufsize], 1);
}
}
}
}
} // namespace
class CudaLaunchConfigTest : public ::testing::Test {
protected:
const int bufsize = 1024;
int* outbuf = nullptr;
Eigen::CudaStreamDevice stream;
GPUDevice d = GPUDevice(&stream);
virtual void SetUp() {
cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize);
ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err);
}
virtual void TearDown() {
cudaDeviceSynchronize();
cudaFree(outbuf);
outbuf = nullptr;
}
};
TEST_F(CudaLaunchConfigTest, GetCudaLaunchConfig) {
CudaLaunchConfig cfg;
// test invalid inputs
CudaLaunchConfig default_value;
cfg = GetCudaLaunchConfig(0, d);
EXPECT_EQ(default_value.virtual_thread_count, cfg.virtual_thread_count);
EXPECT_EQ(default_value.block_count, cfg.block_count);
EXPECT_EQ(default_value.thread_per_block, cfg.thread_per_block);
cfg = GetCudaLaunchConfig(-1, d);
EXPECT_EQ(default_value.virtual_thread_count, cfg.virtual_thread_count);
EXPECT_EQ(default_value.block_count, cfg.block_count);
EXPECT_EQ(default_value.thread_per_block, cfg.thread_per_block);
cfg = GetCudaLaunchConfig(0, d, Count1D, 0, 0);
EXPECT_EQ(default_value.virtual_thread_count, cfg.virtual_thread_count);
EXPECT_EQ(default_value.block_count, cfg.block_count);
EXPECT_EQ(default_value.thread_per_block, cfg.thread_per_block);
cfg = GetCudaLaunchConfig(-1, d, Count1D, 0, 0);
EXPECT_EQ(default_value.virtual_thread_count, cfg.virtual_thread_count);
EXPECT_EQ(default_value.block_count, cfg.block_count);
EXPECT_EQ(default_value.thread_per_block, cfg.thread_per_block);
// test valid inputs
#define TEST_LAUNCH_PARAMETER(work_element_count) \
cfg = GetCudaLaunchConfig(bufsize, d); \
SetOutbufZero<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> \
(cfg, outbuf); \
CUDA_ASSERT_SUCCESS \
cfg = GetCudaLaunchConfig(work_element_count, d); \
Count1D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> ( \
cfg, bufsize, outbuf); \
CUDA_EXPECT_SUCCESS \
EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0));\
\
cfg = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \
SetOutbufZero<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> \
(cfg, outbuf); \
CUDA_ASSERT_SUCCESS \
cfg = GetCudaLaunchConfig(work_element_count, d, Count1D, 0, 0); \
Count1D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> ( \
cfg, bufsize, outbuf); \
CUDA_EXPECT_SUCCESS \
EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0))
TEST_LAUNCH_PARAMETER(128);
TEST_LAUNCH_PARAMETER(129);
TEST_LAUNCH_PARAMETER(511);
TEST_LAUNCH_PARAMETER(512);
TEST_LAUNCH_PARAMETER(2048);
TEST_LAUNCH_PARAMETER(2049);
TEST_LAUNCH_PARAMETER(8191);
TEST_LAUNCH_PARAMETER(8192);
TEST_LAUNCH_PARAMETER(123456);
TEST_LAUNCH_PARAMETER(1 << 31 - 1); // max value of int
#undef TEST_LAUNCH_PARAMETER
}
bool operator==(const Cuda2DLaunchConfig& a, const Cuda2DLaunchConfig& b) {
return a.thread_per_block.x == b.thread_per_block.x &&
a.thread_per_block.y == b.thread_per_block.y &&
a.thread_per_block.z == b.thread_per_block.z &&
a.block_count.x == b.block_count.x &&
a.block_count.y == b.block_count.y &&
a.block_count.z == b.block_count.z &&
a.thread_per_block.x == b.thread_per_block.x &&
a.thread_per_block.y == b.thread_per_block.y &&
a.thread_per_block.z == b.thread_per_block.z;
}
TEST_F(CudaLaunchConfigTest, GetCuda2DLaunchConfig) {
Cuda2DLaunchConfig cfg;
CudaLaunchConfig cfg1d;
// test invalid inputs
Cuda2DLaunchConfig default_value;
cfg = GetCuda2DLaunchConfig(1, 0, d);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda2DLaunchConfig(1, -1, d);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda2DLaunchConfig(-1, 1, d);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda2DLaunchConfig(-1, 1, d);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda2DLaunchConfig(0, -1, d);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda2DLaunchConfig(0, 0, d);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda2DLaunchConfig(1, 0, d, Count2D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda2DLaunchConfig(1, -1, d, Count2D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda2DLaunchConfig(-1, 1, d, Count2D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda2DLaunchConfig(-1, 1, d, Count2D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda2DLaunchConfig(0, -1, d, Count2D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda2DLaunchConfig(0, 0, d, Count2D, 0, 0);
EXPECT_EQ(default_value, cfg);
// test valid inputs
#define TEST_LAUNCH_PARAMETER(dimx, dimy) \
cfg1d = GetCudaLaunchConfig(bufsize, d); \
SetOutbufZero<<<cfg1d.block_count, cfg1d.thread_per_block, 0, d.stream()>>> \
(cfg1d, outbuf);\
CUDA_ASSERT_SUCCESS \
cfg = GetCuda2DLaunchConfig(dimx, dimy, d); \
Count2D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> ( \
cfg, bufsize, outbuf); \
CUDA_EXPECT_SUCCESS \
EXPECT_EQ(dimx * dimy, std::accumulate(outbuf, outbuf + bufsize, 0)); \
\
cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \
SetOutbufZero<<<cfg1d.block_count, cfg1d.thread_per_block, 0, d.stream()>>> \
(cfg1d, outbuf);\
CUDA_ASSERT_SUCCESS \
cfg = GetCuda2DLaunchConfig(dimx, dimy, d, Count2D, 0, 0); \
Count2D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> ( \
cfg, bufsize, outbuf); \
CUDA_EXPECT_SUCCESS \
EXPECT_EQ(dimx * dimy, std::accumulate(outbuf, outbuf + bufsize, 0))
TEST_LAUNCH_PARAMETER(128, 128);
TEST_LAUNCH_PARAMETER(129, 64);
TEST_LAUNCH_PARAMETER(511, 2048);
TEST_LAUNCH_PARAMETER(512, 512);
TEST_LAUNCH_PARAMETER(2048, 1024);
TEST_LAUNCH_PARAMETER(2049, 32);
TEST_LAUNCH_PARAMETER(8191, 1);
TEST_LAUNCH_PARAMETER(8192, 10);
TEST_LAUNCH_PARAMETER(123456, 12);
TEST_LAUNCH_PARAMETER(1, (1 << 31 - 1));
TEST_LAUNCH_PARAMETER((1 << 31 - 1), 1);
#undef TEST_LAUNCH_PARAMETER
}
TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) {
Cuda3DLaunchConfig cfg;
CudaLaunchConfig cfg1d;
// test invalid inputs
Cuda3DLaunchConfig default_value;
cfg = GetCuda3DLaunchConfig(0, 1, 1, d, Count3D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda3DLaunchConfig(-1, 1, 1, d, Count3D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda3DLaunchConfig(1, 0, 1, d, Count3D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda3DLaunchConfig(1, -1, 1, d, Count3D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda3DLaunchConfig(1, 1, 0, d, Count3D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda3DLaunchConfig(1, 1, -1, d, Count3D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda3DLaunchConfig(0, 0, 0, d, Count3D, 0, 0);
EXPECT_EQ(default_value, cfg);
cfg = GetCuda3DLaunchConfig(-1, -1, -1, d, Count3D, 0, 0);
EXPECT_EQ(default_value, cfg);
// test valid inputs
#define TEST_LAUNCH_PARAMETER(dimx, dimy, dimz) \
cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \
SetOutbufZero<<<cfg1d.block_count, cfg1d.thread_per_block, 0, d.stream()>>> \
(cfg1d, outbuf);\
CUDA_ASSERT_SUCCESS \
cfg = GetCuda3DLaunchConfig(dimx, dimy, dimz, d, Count3D, 0, 0); \
Count3D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>> ( \
cfg, bufsize, outbuf); \
CUDA_EXPECT_SUCCESS \
EXPECT_EQ(dimx * dimy * dimz, std::accumulate(outbuf, outbuf + bufsize, 0))
TEST_LAUNCH_PARAMETER(128, 128, 128);
TEST_LAUNCH_PARAMETER(129, 64, 1024);
TEST_LAUNCH_PARAMETER(511, 2048, 128);
TEST_LAUNCH_PARAMETER(512, 512, 64);
TEST_LAUNCH_PARAMETER(2048, 1024, 128);
TEST_LAUNCH_PARAMETER(2049, 32, 1024);
TEST_LAUNCH_PARAMETER(8191, 1, 1024);
TEST_LAUNCH_PARAMETER(8192, 10, 32);
TEST_LAUNCH_PARAMETER(123456, 12, 21);
TEST_LAUNCH_PARAMETER(1, 1, (1 << 31 - 1));
TEST_LAUNCH_PARAMETER(1, (1 << 31 - 1), 1);
TEST_LAUNCH_PARAMETER((1 << 31 - 1), 1, 1);
#undef TEST_LAUNCH_PARAMETER
}
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -461,6 +461,29 @@ def tf_cuda_cc_test(name,
linkopts=linkopts,
args=args)
def tf_cuda_only_cc_test(name,
srcs=[],
deps=[],
tags=[],
data=[],
size="medium",
linkstatic=0,
args=[],
linkopts=[]):
native.cc_test(
name="%s%s" % (name, "_gpu"),
srcs=srcs,
size=size,
args=args,
copts= _cuda_copts() + tf_copts(),
data=data,
deps=deps + if_cuda([
clean_dep("//tensorflow/core:cuda"),
clean_dep("//tensorflow/core:gpu_lib"),
]),
linkopts=["-lpthread", "-lm"] + linkopts,
linkstatic=linkstatic,
tags=tags)
# Create a cc_test for each of the tensorflow tests listed in "tests"
def tf_cc_tests(srcs,