mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Rollback PR #39429: [ROCm] Fixing and enabling TopK
PiperOrigin-RevId: 320655274 Change-Id: Ie7c4bfa6bdef5120f8fe4afd4cc26cb84efa39c0
This commit is contained in:
parent
08c65909e3
commit
195c91d7ba
|
|
@ -14,10 +14,6 @@ limitations under the license, the license you must see.
|
|||
#ifndef TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/cub/block/block_load.cuh"
|
||||
#include "third_party/cub/block/block_scan.cuh"
|
||||
|
|
@ -35,35 +31,10 @@ limitations under the license, the license you must see.
|
|||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
|
||||
namespace gpuprim = ::cub;
|
||||
|
||||
// Required for sorting Eigen::half
|
||||
namespace cub {
|
||||
template <>
|
||||
struct NumericTraits<Eigen::half>
|
||||
: BaseTraits<FLOATING_POINT, true, false, unsigned short, Eigen::half> {};
|
||||
|
||||
// Provide overload for CUB to assign to volatile Eigen::half.
|
||||
template <>
|
||||
__device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::half>(
|
||||
Eigen::half *ptr, Eigen::half val, Int2Type<true> /*is_primitive*/) {
|
||||
reinterpret_cast<volatile unsigned short &>(ptr->x) = val.x;
|
||||
}
|
||||
|
||||
// Provide overload for CUB to load from volatile Eigen::half.
|
||||
template <>
|
||||
__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer<Eigen::half>(
|
||||
Eigen::half *ptr, Int2Type<true> /*is_primitive*/) {
|
||||
auto x = reinterpret_cast<const volatile unsigned short &>(ptr->x);
|
||||
return Eigen::half_impl::raw_uint16_to_half(x);
|
||||
}
|
||||
|
||||
} // namespace cub
|
||||
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hipcub/hipcub.hpp"
|
||||
namespace gpuprim = ::hipcub;
|
||||
|
||||
// Required for sorting Eigen::half
|
||||
namespace rocprim {
|
||||
namespace detail {
|
||||
template <>
|
||||
|
|
@ -71,6 +42,6 @@ struct radix_key_codec_base<Eigen::half>
|
|||
: radix_key_codec_floating<Eigen::half, unsigned short> {};
|
||||
}; // namespace detail
|
||||
}; // namespace rocprim
|
||||
#endif // TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
|
||||
|
|
|
|||
|
|
@ -76,9 +76,9 @@ static Graph* InTopK(int num_targets, int num_classes, T top_k) {
|
|||
BM_InTopK(int64, 64, 1000, 10, cpu);
|
||||
BM_InTopK(int64, 64, 10000, 10, cpu);
|
||||
|
||||
#if defined GOOGLE_CUDA || defined TENSORFLOW_USE_ROCM
|
||||
#ifdef GOOGLE_CUDA
|
||||
BM_InTopK(int64, 64, 1000, 10, gpu);
|
||||
BM_InTopK(int64, 64, 10000, 10, gpu);
|
||||
#endif // defined GOOGLE_CUDA || defined TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
|
|||
#undef REGISTER_KERNELS_NAME
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#ifdef GOOGLE_CUDA
|
||||
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
|
|
@ -277,6 +277,6 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
|
|||
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
|
||||
#undef REGISTER_KERNELS
|
||||
|
||||
#endif // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // end GOOGLE_CUDA
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -15,12 +15,11 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
|
@ -35,6 +34,15 @@ limitations under the License.
|
|||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
// Required for sorting Eigen::half
|
||||
namespace cub {
|
||||
template <>
|
||||
struct NumericTraits<Eigen::half>
|
||||
: BaseTraits<FLOATING_POINT, true, false, unsigned short, Eigen::half> {};
|
||||
} // namespace cub
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
|
@ -85,7 +93,7 @@ struct IndirectLinearData {
|
|||
Entry* const backing_data;
|
||||
};
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
template <typename T>
|
||||
struct StridedData {
|
||||
typedef impl::Entry<T> Entry;
|
||||
|
|
@ -107,7 +115,6 @@ template <HeapType heapType, PreferIndices preferIndices,
|
|||
struct IndexedHeap {
|
||||
typedef typename Data<T>::Entry Entry;
|
||||
const Data<T> data;
|
||||
__device__ IndexedHeap(const Data<T>& d) : data(d) {}
|
||||
|
||||
__device__ bool is_above(int left, int right) {
|
||||
T left_value = data.get_value(left);
|
||||
|
|
@ -330,18 +337,12 @@ __device__ void mergeShards(int num_shards, int k,
|
|||
}
|
||||
}
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
extern __shared__ char shared_memory[];
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__attribute__((amdgpu_flat_work_group_size(1, 256))) __global__ void TopKKernel(
|
||||
const T* __restrict__ input, int length, int k, bool sorted,
|
||||
T* __restrict__ output, int* __restrict__ indices) {
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
HIP_DYNAMIC_SHARED(char, shared_memory);
|
||||
#endif
|
||||
|
||||
__global__ void TopKKernel(const T* __restrict__ input, int length, int k,
|
||||
bool sorted, T* __restrict__ output,
|
||||
int* __restrict__ indices) {
|
||||
const int batch_index = blockIdx.x;
|
||||
const T* batch_input = input + batch_index * length;
|
||||
|
||||
|
|
@ -369,7 +370,7 @@ __attribute__((amdgpu_flat_work_group_size(1, 256))) __global__ void TopKKernel(
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
cudaError LaunchTopKKernel(const gpuStream_t& stream, int num_shards,
|
||||
cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards,
|
||||
const T* input, int batch_size, int length, int k,
|
||||
bool sorted, T* output, int* indices) {
|
||||
// This code assumes that k is small enough that the computation
|
||||
|
|
@ -394,17 +395,9 @@ cudaError LaunchTopKKernel(const gpuStream_t& stream, int num_shards,
|
|||
}
|
||||
if (num_shards <= 0) {
|
||||
num_shards = 1;
|
||||
#if GOOGLE_CUDA
|
||||
} else if (num_shards > 1024) {
|
||||
num_shards = 1024;
|
||||
}
|
||||
#else
|
||||
// ROCm can't execute with 1024 and requires an explicit
|
||||
// amdgpu_flat_work_group_size attribute with >256
|
||||
} else if (num_shards > 256) {
|
||||
num_shards = 256;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
// We are limited by the amount of shared memory we have per block.
|
||||
auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);
|
||||
|
|
@ -455,9 +448,9 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
|
|||
input_indices_t.device(d) =
|
||||
input_indices_t.generate(ColumnIndexCreator(num_cols));
|
||||
|
||||
gpuprim::CountingInputIterator<int> counting_iter(0);
|
||||
gpuprim::TransformInputIterator<int, SegmentOffsetCreator,
|
||||
gpuprim::CountingInputIterator<int>>
|
||||
cub::CountingInputIterator<int> counting_iter(0);
|
||||
cub::TransformInputIterator<int, SegmentOffsetCreator,
|
||||
cub::CountingInputIterator<int>>
|
||||
segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols));
|
||||
|
||||
Tensor temp_values;
|
||||
|
|
@ -479,7 +472,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
|
|||
sorted_values_ptr = temp_values.flat<T>().data();
|
||||
}
|
||||
|
||||
auto err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
/* d_temp_storage */ nullptr,
|
||||
/* temp_storage_bytes */ temp_storage_bytes,
|
||||
/* d_keys_in */ input,
|
||||
|
|
@ -496,8 +489,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
|
|||
if (err != cudaSuccess) {
|
||||
return errors::Internal(
|
||||
"TopKOp: Could not launch "
|
||||
"cub::gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to "
|
||||
"calculate "
|
||||
"cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
|
||||
"temp_storage_bytes, status: ",
|
||||
cudaGetErrorString(err));
|
||||
}
|
||||
|
|
@ -505,7 +497,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
|
|||
TF_RETURN_IF_ERROR(ctx->allocate_temp(
|
||||
DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
|
||||
&temp_storage));
|
||||
err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
|
||||
/* d_temp_storage */ temp_storage.flat<int8>().data(),
|
||||
/* temp_storage_bytes */ temp_storage_bytes,
|
||||
/* d_keys_in */ input,
|
||||
|
|
@ -522,8 +514,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
|
|||
if (err != cudaSuccess) {
|
||||
return errors::Internal(
|
||||
"TopKOp: Could not launch "
|
||||
"cub::gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to sort "
|
||||
"input, "
|
||||
"cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
|
||||
"temp_storage_bytes: ",
|
||||
temp_storage_bytes, ", status: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
|
@ -576,6 +567,6 @@ struct TopKFunctor<GPUDevice, T> {
|
|||
} // end namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/topk_op.h"
|
||||
|
|
@ -25,4 +25,4 @@ using Eigen::GpuDevice;
|
|||
template struct functor::TopKFunctor<GPUDevice, double>;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/topk_op.h"
|
||||
|
|
@ -25,4 +25,4 @@ using Eigen::GpuDevice;
|
|||
template struct functor::TopKFunctor<GPUDevice, float>;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/topk_op.h"
|
||||
|
|
@ -25,4 +25,4 @@ using Eigen::GpuDevice;
|
|||
template struct functor::TopKFunctor<GPUDevice, Eigen::half>;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/topk_op.h"
|
||||
|
|
@ -25,4 +25,4 @@ using Eigen::GpuDevice;
|
|||
template struct functor::TopKFunctor<GPUDevice, int16>;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/topk_op.h"
|
||||
|
|
@ -25,4 +25,4 @@ using Eigen::GpuDevice;
|
|||
template struct functor::TopKFunctor<GPUDevice, int32>;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/topk_op.h"
|
||||
|
|
@ -25,4 +25,4 @@ using Eigen::GpuDevice;
|
|||
template struct functor::TopKFunctor<GPUDevice, int64>;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/topk_op.h"
|
||||
|
|
@ -25,4 +25,4 @@ using Eigen::GpuDevice;
|
|||
template struct functor::TopKFunctor<GPUDevice, int8>;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/topk_op.h"
|
||||
|
|
@ -27,4 +27,4 @@ template struct functor::TopKFunctor<GPUDevice, uint32>;
|
|||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/topk_op.h"
|
||||
|
|
@ -25,4 +25,4 @@ using Eigen::GpuDevice;
|
|||
template struct functor::TopKFunctor<GPUDevice, uint8>;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#endif // GOOGLE_CUDA
|
||||
|
|
|
|||
|
|
@ -102,13 +102,11 @@ class TopKTest(test.TestCase):
|
|||
self._validateTopK(inputs, 2, [[0.4, 0.3], [0.4, 0.3]], [[3, 1], [2, 1]])
|
||||
|
||||
def testTop3(self):
|
||||
for k in range(3, 11, 2):
|
||||
for dim in range(512, 12288, 512):
|
||||
inputs = np.random.permutation(
|
||||
np.linspace(0, 100, dim, dtype=np.float64))
|
||||
indices = np.argsort(-inputs)[:k]
|
||||
values = -np.sort(-inputs)[:k]
|
||||
self._validateTopK(inputs, k, values, indices)
|
||||
k = 5
|
||||
inputs = np.random.permutation(np.linspace(0, 100, 6140, dtype=np.float64))
|
||||
indices = np.argsort(-inputs)[:k]
|
||||
values = -np.sort(-inputs)[:k]
|
||||
self._validateTopK(inputs, k, values, indices)
|
||||
|
||||
def testTop1AllNan(self):
|
||||
inputs = [[np.NaN, np.NaN], [np.NaN, np.NaN]]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user