Make the gather CPU kernel multi-threaded. (#12246)
* Change the gather op to be multi-threaded.
* Modify the gather kernel of the XLA compiler to be compatible with the multi-threaded CPU kernel.
* Add prefetch logic to the gather op kernel.
* Update the indentation of the gather op kernel code.
* Update the gather kernel code for multiple threads.
* Remove the reference to an earlier version of the code in the gather functor.
* Change the framework_lite dep of gather_functor to framework.
* Remove the mutex guard in the gather functor.
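The core of the change, summarized above, is to flatten the (batch, index) double loop into batch_size * indices_size work units and hand contiguous [start, end) ranges to CPU worker threads. The sketch below illustrates that pattern with a plain std::thread stand-in; it is not TensorFlow's Shard helper, and ShardLike is an illustrative name only.

// Minimal stand-in for the sharding pattern adopted by this commit.
// This is NOT TensorFlow's Shard API, just a plain std::thread illustration.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <thread>
#include <vector>

// Split [0, total) into contiguous ranges and run work(start, end) on each.
// (The real Shard also weighs a cost-per-unit hint when picking range sizes.)
void ShardLike(int num_threads, int64_t total,
               const std::function<void(int64_t, int64_t)>& work) {
  const int64_t block = (total + num_threads - 1) / num_threads;
  std::vector<std::thread> threads;
  for (int64_t start = 0; start < total; start += block) {
    const int64_t end = std::min(start + block, total);
    threads.emplace_back(work, start, end);
  }
  for (auto& t : threads) t.join();
}

int main() {
  const int64_t batch_size = 3, indices_size = 5;
  // Each flattened unit maps back to a (batch, index) pair, exactly as the
  // kernel's work lambda does with start / indices_size and start % indices_size.
  // Output order depends on thread scheduling.
  ShardLike(4, batch_size * indices_size, [&](int64_t start, int64_t end) {
    for (int64_t pos = start; pos < end; ++pos) {
      std::printf("unit %lld -> batch %lld, index %lld\n",
                  static_cast<long long>(pos),
                  static_cast<long long>(pos / indices_size),
                  static_cast<long long>(pos % indices_size));
    }
  });
  return 0;
}

TensorFlow's real Shard call (visible in the diff below) additionally takes a cost-per-unit hint, here slice_elems * sizeof(T), so that small gathers are not over-parallelized.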
parent 9f3d65cfb1
commit 07a91dac54
@@ -1098,7 +1098,7 @@ tf_kernel_library(
     visibility = [":friends"],
     deps = [
         ":bounds_check",
-        "//tensorflow/core:framework_lite",
+        "//tensorflow/core:framework",
         "//third_party/eigen3",
     ],
 )

@@ -28,7 +28,7 @@ namespace functor {
 #define DECLARE_GPU_SPECS_INDEX(T, Index)                              \
   template <>                                                          \
   int64 GatherFunctor<GPUDevice, T, Index>::operator()(                \
-      const GPUDevice& d, typename TTypes<T, 3>::ConstTensor Tparams,  \
+      OpKernelContext* ctx, typename TTypes<T, 3>::ConstTensor Tparams, \
       typename TTypes<Index>::ConstFlat Tindices,                      \
       typename TTypes<T, 3>::Tensor Tout);                             \
   extern template struct GatherFunctor<GPUDevice, T, Index>;

@@ -23,6 +23,8 @@ limitations under the License.
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/platform/prefetch.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/work_sharder.h"
 
 namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;

@@ -32,7 +34,8 @@ namespace functor {
 // Helper method to copy using memcpy.
 template <typename T, typename Index, typename SliceIndex,
           SliceIndex static_slice_elems>
-SliceIndex HandleCopies(typename TTypes<T, 3>::ConstTensor params,
+SliceIndex HandleCopies(OpKernelContext* ctx,
+                        typename TTypes<T, 3>::ConstTensor params,
                         typename TTypes<Index>::ConstFlat indices,
                         SliceIndex slice_elems,
                         typename TTypes<T, 3>::Tensor out) {

@@ -47,44 +50,64 @@ SliceIndex HandleCopies(typename TTypes<T, 3>::ConstTensor params,
   }
   // Compute slice_bytes here so that static knowledge is available
   const size_t slice_bytes = slice_elems * sizeof(T);
-  for (SliceIndex b = 0; b < batch_size; b++) {
-    for (SliceIndex i = 0; i < indices_size; i++) {
-      const SliceIndex i_next = i + 1;
-      const SliceIndex b_next = b + 1;
-      if (i_next < indices_size) {
-        port::prefetch<port::PREFETCH_HINT_T0>(&params(b, indices(i_next), 0));
-        port::prefetch<port::PREFETCH_HINT_T0>(&out(b, i_next, 0));
-      } else if (b_next < batch_size) {
+  auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
+  mutex mu;
+  // Store the value of invalidate index for printing error information, it's a shared variable.
+  SliceIndex result = -1;
+  auto work = [&] (int64 start, int64 end) {
+    SliceIndex batch_idx = static_cast<SliceIndex>(start / indices_size);
+    SliceIndex indices_idx = static_cast<SliceIndex>(start % indices_size);
+    SliceIndex batch_idx_end = static_cast<SliceIndex>(end / indices_size);
+    SliceIndex indices_idx_end = static_cast<SliceIndex>(end % indices_size);
+
+    while ((batch_idx < batch_idx_end) ||
+           (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) {
+      SliceIndex i_next = indices_idx + 1;
+      SliceIndex b_next = batch_idx + 1;
+      if ((batch_idx == batch_idx_end && i_next < indices_idx_end) ||
+          (i_next < indices_size)) {
+        port::prefetch<port::PREFETCH_HINT_T0>(&params(batch_idx, indices(i_next), 0));
+        port::prefetch<port::PREFETCH_HINT_T0>(&out(batch_idx, i_next, 0));
+        b_next = batch_idx;
+      } else if (b_next <= batch_idx_end) {
         port::prefetch<port::PREFETCH_HINT_T0>(&params(b_next, indices(0), 0));
         port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, 0, 0));
+        i_next = 0;
       }
-      // Grab the index and check its validity. An earlier version of the
-      // code checked it and then grabbed it from memory a second time, which
-      // was a security risk since it could have changed in between.
-      const Index index = internal::SubtleMustCopy(indices(i));
-      if (!FastBoundsCheck(index, limit)) return i;
+      const Index index = internal::SubtleMustCopy(indices(indices_idx));
+      if (!FastBoundsCheck(index, limit)) {
+        mutex_lock l(mu);
+        result = indices_idx;
+        return;
+      }
       // Copy using memcpy if possible, otherwise an Eigen loop
       // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve
       // ahead-of-time compilation binary size).
       if (is_simple_type<T>::value) {
         // Avoid auto-promotion to Index from SliceIndex by casting.
-        memcpy(out_base + (b * indices_size + i) * slice_elems,
-               params_base + (b * static_cast<SliceIndex>(limit) +
+        memcpy(out_base + (batch_idx * indices_size + indices_idx) * slice_elems,
+               params_base + (batch_idx * static_cast<SliceIndex>(limit) +
                               static_cast<SliceIndex>(index)) *
-                                 slice_elems,
+                              slice_elems,
                slice_bytes);
       } else {
         // For non-"simple" types (e.g. strings).
-        out.template chip<1>(i) = params.template chip<1>(index);
+        out.template chip<1>(indices_idx) = params.template chip<1>(index);
       }
+      indices_idx = i_next;
+      batch_idx = b_next;
     }
-  }
-  return -1;
+  };
+
+  Shard(worker_threads->num_threads, worker_threads->workers, batch_size*indices_size,
+        slice_elems * sizeof(T), work);
+  return result;
 }
 
 template <typename T, typename Index>
 struct GatherFunctorCPU {
-  int64 operator()(typename TTypes<T, 3>::ConstTensor params,
+  int64 operator()(OpKernelContext* ctx,
+                   typename TTypes<T, 3>::ConstTensor params,
                    typename TTypes<Index>::ConstFlat indices,
                    typename TTypes<T, 3>::Tensor out) {
     const int64 N = indices.size();

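The subtle part of the new kernel above is how each shard decomposes its flattened [start, end) range back into (batch_idx, indices_idx) pairs and how the prefetch target is picked at batch boundaries. Below is a minimal standalone sketch of that traversal, with the memcpy replaced by a print so it runs outside TensorFlow; WalkRange is an illustrative name, not part of the codebase.

// Standalone walk of the sharded range traversal used by the gather kernel.
// The copy body is replaced with a print so the logic can run by itself.
#include <cstdint>
#include <cstdio>

void WalkRange(int64_t start, int64_t end, int64_t indices_size) {
  int64_t batch_idx = start / indices_size;
  int64_t indices_idx = start % indices_size;
  const int64_t batch_idx_end = end / indices_size;
  const int64_t indices_idx_end = end % indices_size;
  while (batch_idx < batch_idx_end ||
         (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) {
    int64_t i_next = indices_idx + 1;
    int64_t b_next = batch_idx + 1;
    if ((batch_idx == batch_idx_end && i_next < indices_idx_end) ||
        i_next < indices_size) {
      // Next element is in the same batch: the kernel prefetches
      // params(batch_idx, indices(i_next), 0) and out(batch_idx, i_next, 0).
      b_next = batch_idx;
    } else if (b_next <= batch_idx_end) {
      // Crossing into the next batch: the kernel prefetches its first slice.
      i_next = 0;
    }
    std::printf("copy slice (batch %lld, index %lld)\n",
                static_cast<long long>(batch_idx),
                static_cast<long long>(indices_idx));
    indices_idx = i_next;
    batch_idx = b_next;
  }
}

int main() {
  // One shard covering flattened units [3, 11) of a gather with 5 indices
  // per batch; it should visit (0,3), (0,4), (1,0) ... (1,4), (2,0).
  WalkRange(3, 11, /*indices_size=*/5);
  return 0;
}

Any shard that hits an out-of-bounds index records it in the mutex-guarded result variable and returns early, which is how the multi-threaded kernel preserves the old return-the-bad-index contract.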
@@ -94,16 +117,16 @@ struct GatherFunctorCPU {
     bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
                       params.size() > std::numeric_limits<int32>::max() ||
                       N > std::numeric_limits<int32>::max());
-#define CALL(elems)                                                   \
-  do {                                                                \
-    if (use_large) {                                                  \
-      bad_i = HandleCopies<T, Index, int64, elems>(params, indices,   \
-                                                   slice_size, out);  \
-    } else {                                                          \
-      const int32 small_slice = static_cast<int32>(slice_size);       \
-      bad_i = HandleCopies<T, Index, int32, elems>(params, indices,   \
-                                                   small_slice, out); \
-    }                                                                 \
+#define CALL(elems)                                                        \
+  do {                                                                     \
+    if (use_large) {                                                       \
+      bad_i = HandleCopies<T, Index, int64, elems>(ctx, params, indices,   \
+                                                   slice_size, out);       \
+    } else {                                                               \
+      const int32 small_slice = static_cast<int32>(slice_size);            \
+      bad_i = HandleCopies<T, Index, int32, elems>(ctx, params, indices,   \
+                                                   small_slice, out);      \
+    }                                                                      \
   } while (0)
 
   if (slice_size == 10)

@@ -120,18 +143,18 @@ struct GatherFunctorCPU {
 
 template <typename Device, typename T, typename Index>
 struct GatherFunctor {
-  int64 operator()(const Device& d, typename TTypes<T, 3>::ConstTensor params,
+  int64 operator()(OpKernelContext* ctx, typename TTypes<T, 3>::ConstTensor params,
                    typename TTypes<Index>::ConstFlat indices,
                    typename TTypes<T, 3>::Tensor out);
 };
 
 template <typename T, typename Index>
 struct GatherFunctor<CPUDevice, T, Index> {
-  int64 operator()(const CPUDevice& d,
+  int64 operator()(OpKernelContext* ctx,
                    typename TTypes<T, 3>::ConstTensor params,
                    typename TTypes<Index>::ConstFlat indices,
                    typename TTypes<T, 3>::Tensor out) {
-    return GatherFunctorCPU<T, Index>()(params, indices, out);
+    return GatherFunctorCPU<T, Index>()(ctx, params, indices, out);
   }
 };
 

@@ -72,10 +72,11 @@ __global__ void GatherOpKernel(const T* params, const Index* indices, T* out,
 namespace functor {
 template <typename T, typename Index>
 struct GatherFunctor<GPUDevice, T, Index> {
-  int64 operator()(const GPUDevice& d,
+  int64 operator()(OpKernelContext* ctx,
                    typename TTypes<T, 3>::ConstTensor params,
                    typename TTypes<Index>::ConstFlat indices,
                    typename TTypes<T, 3>::Tensor out) {
+    const GPUDevice& d = ctx->eigen_gpu_device();
     const int64 out_size = out.size();
     if (out_size == 0) {
       // We need a check here since the CPU version does useful error checking

@@ -106,7 +106,7 @@ class GatherOp : public OpKernel {
     auto out_flat = out->shaped<T, 3>({outer_size, N, inner_size});
 
     functor::GatherFunctor<Device, T, Index> functor;
-    int64 bad_i = functor(c->eigen_device<Device>(), params_flat,
+    int64 bad_i = functor(c, params_flat,
                           indices_flat, out_flat);
 
     OP_REQUIRES(

@@ -464,7 +464,7 @@ class ResourceGatherOp : public OpKernel {
     auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});
 
     functor::GatherFunctor<Device, T, Index> functor;
-    int64 bad_i = functor(c->eigen_device<Device>(), params_flat,
+    int64 bad_i = functor(c, params_flat,
                           indices_flat, out_flat);
 
     OP_REQUIRES(