Rollback PR #39429: [ROCm] Fixing and enabling TopK

PiperOrigin-RevId: 320655274
Change-Id: Ie7c4bfa6bdef5120f8fe4afd4cc26cb84efa39c0
Mihai Maruseac 2020-07-10 12:39:45 -07:00 committed by TensorFlower Gardener
parent 08c65909e3
commit 195c91d7ba
14 changed files with 51 additions and 91 deletions

tensorflow/core/kernels/gpu_prim.h

@@ -14,10 +14,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if GOOGLE_CUDA
#include "third_party/cub/block/block_load.cuh"
#include "third_party/cub/block/block_scan.cuh"
@@ -35,35 +31,10 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cusparse.h"
namespace gpuprim = ::cub;
// Required for sorting Eigen::half
namespace cub {
template <>
struct NumericTraits<Eigen::half>
: BaseTraits<FLOATING_POINT, true, false, unsigned short, Eigen::half> {};
// Provide overload for CUB to assign to volatile Eigen::half.
template <>
__device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::half>(
Eigen::half *ptr, Eigen::half val, Int2Type<true> /*is_primitive*/) {
reinterpret_cast<volatile unsigned short &>(ptr->x) = val.x;
}
// Provide overload for CUB to load from volatile Eigen::half.
template <>
__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer<Eigen::half>(
Eigen::half *ptr, Int2Type<true> /*is_primitive*/) {
auto x = reinterpret_cast<const volatile unsigned short &>(ptr->x);
return Eigen::half_impl::raw_uint16_to_half(x);
}
} // namespace cub
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hipcub/hipcub.hpp"
namespace gpuprim = ::hipcub;
// Required for sorting Eigen::half
namespace rocprim {
namespace detail {
template <>
@@ -71,6 +42,6 @@ struct radix_key_codec_base<Eigen::half>
: radix_key_codec_floating<Eigen::half, unsigned short> {};
}; // namespace detail
}; // namespace rocprim
#endif // TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA
#endif // TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
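
For context: gpu_prim.h exists so kernels can be written once against the gpuprim:: alias and build for either backend. A minimal consumer sketch (hypothetical helper, not part of this commit; under ROCm the return and stream types would be hipError_t/hipStream_t):

#include "tensorflow/core/kernels/gpu_prim.h"

// Sort n keys ascending on the given stream. The temp/temp_bytes pair
// follows CUB's two-phase convention shown later in this diff.
template <typename T>
cudaError_t GpuSortKeys(void* temp, size_t& temp_bytes, const T* keys_in,
                        T* keys_out, int n, cudaStream_t stream) {
  return gpuprim::DeviceRadixSort::SortKeys(temp, temp_bytes, keys_in,
                                            keys_out, n, /*begin_bit=*/0,
                                            /*end_bit=*/sizeof(T) * 8, stream);
}

The NumericTraits and radix_key_codec_base specializations above are what make T = Eigen::half legal here; neither library knows that type natively.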

tensorflow/core/kernels/in_topk_op_test.cc

@@ -76,9 +76,9 @@ static Graph* InTopK(int num_targets, int num_classes, T top_k) {
BM_InTopK(int64, 64, 1000, 10, cpu);
BM_InTopK(int64, 64, 10000, 10, cpu);
#if defined GOOGLE_CUDA || defined TENSORFLOW_USE_ROCM
#ifdef GOOGLE_CUDA
BM_InTopK(int64, 64, 1000, 10, gpu);
BM_InTopK(int64, 64, 10000, 10, gpu);
#endif // defined GOOGLE_CUDA || defined TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA
} // namespace tensorflow
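
A note on the guard change above: in an #if expression an undefined identifier evaluates to 0, so the two forms differ only in which backends they admit:

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM  // either backend macro non-zero
#endif
#ifdef GOOGLE_CUDA                      // definedness test: CUDA builds only
#endif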

tensorflow/core/kernels/topk_op.cc

@@ -244,7 +244,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T) \
@@ -277,6 +277,6 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#endif // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // end GOOGLE_CUDA
} // end namespace tensorflow
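
For readers without the surrounding file: REGISTER_KERNELS follows TensorFlow's standard registration idiom. A hedged sketch of roughly what the GPU variant expands to (illustrative only; the exact builder chain in topk_op.cc is not part of this diff):

#define REGISTER_KERNELS(type)                           \
  REGISTER_KERNEL_BUILDER(Name("TopKV2")                 \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("k"),          \
                          TopK<GPUDevice, type>)
// TF_CALL_GPU_NUMBER_TYPES / TF_CALL_INTEGRAL_TYPES above invoke this
// macro once per supported dtype.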

tensorflow/core/kernels/topk_op_gpu.h

@@ -15,12 +15,11 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
#define TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include <cmath>
#include <string>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -35,6 +34,15 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#if GOOGLE_CUDA
// Required for sorting Eigen::half
namespace cub {
template <>
struct NumericTraits<Eigen::half>
: BaseTraits<FLOATING_POINT, true, false, unsigned short, Eigen::half> {};
} // namespace cub
#endif // GOOGLE_CUDA
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
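
The cub::NumericTraits specialization just added back is what lets CUB's radix sort accept Eigen::half keys. Annotated against CUB's BaseTraits parameters (names paraphrased):

// BaseTraits<FLOATING_POINT,  // category: use floating-point bit-twiddling
//            true,            // primitive: bitwise-copyable
//            false,           // not CUB's NullType
//            unsigned short,  // 16-bit unsigned bit representation
//            Eigen::half>     // the key type being described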
@@ -85,7 +93,7 @@ struct IndirectLinearData {
Entry* const backing_data;
};
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
template <typename T>
struct StridedData {
typedef impl::Entry<T> Entry;
@@ -107,7 +115,6 @@ template <HeapType heapType, PreferIndices preferIndices,
struct IndexedHeap {
typedef typename Data<T>::Entry Entry;
const Data<T> data;
__device__ IndexedHeap(const Data<T>& d) : data(d) {}
__device__ bool is_above(int left, int right) {
T left_value = data.get_value(left);
@@ -330,18 +337,12 @@ __device__ void mergeShards(int num_shards, int k,
}
}
#if GOOGLE_CUDA
extern __shared__ char shared_memory[];
#endif
template <typename T>
__attribute__((amdgpu_flat_work_group_size(1, 256))) __global__ void TopKKernel(
const T* __restrict__ input, int length, int k, bool sorted,
T* __restrict__ output, int* __restrict__ indices) {
#if TENSORFLOW_USE_ROCM
HIP_DYNAMIC_SHARED(char, shared_memory);
#endif
__global__ void TopKKernel(const T* __restrict__ input, int length, int k,
bool sorted, T* __restrict__ output,
int* __restrict__ indices) {
const int batch_index = blockIdx.x;
const T* batch_input = input + batch_index * length;
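
The hunk above swaps ROCm's in-kernel HIP_DYNAMIC_SHARED declaration back to CUDA's file-scope extern array. Both idioms name the same dynamically sized buffer, whose byte count comes from the launch configuration; a minimal sketch (kernel name illustrative):

extern __shared__ char shared_memory[];  // CUDA: unsized, sized at launch

__global__ void ExampleKernel(int n) {
  // Under ROCm the equivalent declaration sits here instead:
  //   HIP_DYNAMIC_SHARED(char, shared_memory);
  char* heaps = shared_memory;  // carve per-shard heaps out of the buffer
}

// Launch site: the third <<<>>> argument supplies the byte count:
//   ExampleKernel<<<batch_size, num_shards, shared_memory_size, stream>>>(n);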
@@ -369,7 +370,7 @@ __attribute__((amdgpu_flat_work_group_size(1, 256))) __global__ void TopKKernel(
}
template <typename T>
cudaError LaunchTopKKernel(const gpuStream_t& stream, int num_shards,
cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards,
const T* input, int batch_size, int length, int k,
bool sorted, T* output, int* indices) {
// This code assumes that k is small enough that the computation
@@ -394,17 +395,9 @@ cudaError LaunchTopKKernel(const gpuStream_t& stream, int num_shards,
}
if (num_shards <= 0) {
num_shards = 1;
#if GOOGLE_CUDA
} else if (num_shards > 1024) {
num_shards = 1024;
}
#else
// ROCm can't execute with 1024 and requires an explicit
// amdgpu_flat_work_group_size attribute with >256
} else if (num_shards > 256) {
num_shards = 256;
}
#endif
}
// We are limited by the amount of shared memory we have per block.
auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);
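
Worked example of the sizing line above, assuming Entry<T> packs an int index and a T value (so sizeof(Entry<float>) == 8):

// num_shards = 256 (the ROCm cap being removed), k = 10, T = float:
//   shared_memory_size = (256 + 1) * 10 * 8 = 20,560 bytes,
// comfortably under the common 48 KiB per-block shared-memory budget.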
@@ -455,9 +448,9 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
input_indices_t.device(d) =
input_indices_t.generate(ColumnIndexCreator(num_cols));
gpuprim::CountingInputIterator<int> counting_iter(0);
gpuprim::TransformInputIterator<int, SegmentOffsetCreator,
gpuprim::CountingInputIterator<int>>
cub::CountingInputIterator<int> counting_iter(0);
cub::TransformInputIterator<int, SegmentOffsetCreator,
cub::CountingInputIterator<int>>
segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols));
Tensor temp_values;
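
The iterator pair above materializes segment offsets lazily instead of allocating an array: offset i is simply i * num_cols. A sketch with a stand-in functor (SegmentOffsetCreator itself is defined elsewhere in this file; its shape here is assumed):

#include "third_party/cub/iterator/counting_input_iterator.cuh"
#include "third_party/cub/iterator/transform_input_iterator.cuh"

struct TimesNumCols {  // stand-in for SegmentOffsetCreator
  int num_cols;
  __host__ __device__ int operator()(int i) const { return i * num_cols; }
};

void MakeSegmentOffsets(int num_cols) {
  cub::CountingInputIterator<int> counting(0);  // yields 0, 1, 2, ...
  cub::TransformInputIterator<int, TimesNumCols,
                              cub::CountingInputIterator<int>>
      offsets(counting, TimesNumCols{num_cols});
  // offsets[0] == 0, offsets[1] == num_cols, offsets[2] == 2 * num_cols, ...
}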
@@ -479,7 +472,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
sorted_values_ptr = temp_values.flat<T>().data();
}
auto err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
/* d_temp_storage */ nullptr,
/* temp_storage_bytes */ temp_storage_bytes,
/* d_keys_in */ input,
@@ -496,8 +489,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
if (err != cudaSuccess) {
return errors::Internal(
"TopKOp: Could not launch "
"cub::gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to "
"calculate "
"cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
"temp_storage_bytes, status: ",
cudaGetErrorString(err));
}
@@ -505,7 +497,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
TF_RETURN_IF_ERROR(ctx->allocate_temp(
DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
&temp_storage));
err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
/* d_temp_storage */ temp_storage.flat<int8>().data(),
/* temp_storage_bytes */ temp_storage_bytes,
/* d_keys_in */ input,
@@ -522,8 +514,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
if (err != cudaSuccess) {
return errors::Internal(
"TopKOp: Could not launch "
"cub::gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to sort "
"input, "
"cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
"temp_storage_bytes: ",
temp_storage_bytes, ", status: ", cudaGetErrorString(err));
}
@@ -576,6 +567,6 @@ struct TopKFunctor<GPUDevice, T> {
} // end namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA
#endif // TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_

tensorflow/core/kernels/topk_op_gpu_double.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, double>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA
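
The eight per-type files that follow repeat the same explicit-instantiation pattern shown here: TopKFunctor is defined in the header, and each translation unit forces code generation for exactly one dtype so the instantiations build in parallel. Schematically (abridged; the include of topk_op_gpu.h is assumed from the pattern, as it does not appear in these hunks):

// topk_op_gpu_<type>.cu.cc, schematic
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/kernels/topk_op_gpu.h"
namespace tensorflow {
using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, double>;  // one dtype per file
}  // namespace tensorflow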

tensorflow/core/kernels/topk_op_gpu_float.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, float>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA

tensorflow/core/kernels/topk_op_gpu_half.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, Eigen::half>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA

tensorflow/core/kernels/topk_op_gpu_int16.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, int16>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA

tensorflow/core/kernels/topk_op_gpu_int32.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, int32>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA

tensorflow/core/kernels/topk_op_gpu_int64.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, int64>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA

tensorflow/core/kernels/topk_op_gpu_int8.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, int8>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA

tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
@@ -27,4 +27,4 @@ template struct functor::TopKFunctor<GPUDevice, uint32>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA

tensorflow/core/kernels/topk_op_gpu_uint8.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, uint8>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA

tensorflow/python/kernel_tests/topk_op_test.py

@@ -102,13 +102,11 @@ class TopKTest(test.TestCase):
self._validateTopK(inputs, 2, [[0.4, 0.3], [0.4, 0.3]], [[3, 1], [2, 1]])
def testTop3(self):
for k in range(3, 11, 2):
for dim in range(512, 12288, 512):
inputs = np.random.permutation(
np.linspace(0, 100, dim, dtype=np.float64))
indices = np.argsort(-inputs)[:k]
values = -np.sort(-inputs)[:k]
self._validateTopK(inputs, k, values, indices)
k = 5
inputs = np.random.permutation(np.linspace(0, 100, 6140, dtype=np.float64))
indices = np.argsort(-inputs)[:k]
values = -np.sort(-inputs)[:k]
self._validateTopK(inputs, k, values, indices)
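
The test builds its expected outputs with np.argsort(-inputs)[:k] and -np.sort(-inputs)[:k]. For readers of the kernel sources above, an equivalent host-side reference in C++ (illustrative; ties do not arise because the inputs are a permutation of distinct values):

#include <algorithm>
#include <numeric>
#include <vector>

// Indices of the k largest elements, ordered by descending value.
std::vector<int> TopKIndices(const std::vector<double>& x, int k) {
  std::vector<int> idx(x.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&x](int a, int b) { return x[a] > x[b]; });
  idx.resize(k);
  return idx;
}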
def testTop1AllNan(self):
inputs = [[np.NaN, np.NaN], [np.NaN, np.NaN]]