Make the gather CPU kernel multi-threaded. (#12246)

* Change the gather op to be multi-threaded (see the sharding sketch below).

* Modify the gather kernel of the XLA compiler so that it is compatible with the multi-threaded CPU kernel.

* Add prefetch logic to gather op kernel.

* Update the indentation of the gather op kernel code.

* Update the gather kernel code for multiple threads.

* Remove the reference to an earlier version of the code in the gather functor.

* Change the framework_lite dep of gather_functor to framework.

* Remove mutex guard in gather functor.
nolan liu 2017-11-01 14:35:18 +08:00 committed by gunan
parent 9f3d65cfb1
commit 07a91dac54
6 changed files with 63 additions and 39 deletions
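
For context, the new CPU kernel below treats the gather as one flat range of batch_size * indices_size work items, hands that range to the work sharder with a per-item cost of slice_elems * sizeof(T), and lets each shard decompose its [start, end) interval back into (batch_idx, indices_idx) coordinates. The following is a minimal standalone sketch of that decomposition using plain std::thread rather than TensorFlow's Shard helper; all names and sizes are illustrative and not taken from the commit.

// Illustrative sketch only: split a flat batch_size * indices_size work range
// across threads and map each [start, end) chunk back to (batch_idx,
// indices_idx) pairs, mirroring the sharded loop in the new gather kernel.
#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>

int main() {
  const int64_t batch_size = 4;
  const int64_t indices_size = 1000;
  const int64_t total = batch_size * indices_size;
  const int num_threads = 4;

  auto work = [&](int64_t start, int64_t end) {
    int64_t batch_idx = start / indices_size;
    int64_t indices_idx = start % indices_size;
    const int64_t batch_idx_end = end / indices_size;
    const int64_t indices_idx_end = end % indices_size;
    while (batch_idx < batch_idx_end ||
           (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) {
      // A real kernel would copy params(batch_idx, indices(indices_idx), :)
      // into out(batch_idx, indices_idx, :) here, prefetching the next slice.
      if (++indices_idx == indices_size) {
        indices_idx = 0;
        ++batch_idx;
      }
    }
  };

  // Evenly split [0, total) into contiguous chunks, one per thread; TF's
  // Shard helper additionally weights the split by a per-unit cost estimate.
  std::vector<std::thread> threads;
  const int64_t chunk = (total + num_threads - 1) / num_threads;
  for (int t = 0; t < num_threads; ++t) {
    const int64_t start = t * chunk;
    const int64_t end = std::min<int64_t>(total, start + chunk);
    if (start >= end) break;
    threads.emplace_back(work, start, end);
  }
  for (auto& th : threads) th.join();
  return 0;
}
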


@@ -1098,7 +1098,7 @@ tf_kernel_library(
visibility = [":friends"],
deps = [
":bounds_check",
"//tensorflow/core:framework_lite",
"//tensorflow/core:framework",
"//third_party/eigen3",
],
)


@@ -28,7 +28,7 @@ namespace functor {
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
template <> \
int64 GatherFunctor<GPUDevice, T, Index>::operator()( \
const GPUDevice& d, typename TTypes<T, 3>::ConstTensor Tparams, \
OpKernelContext* ctx, typename TTypes<T, 3>::ConstTensor Tparams, \
typename TTypes<Index>::ConstFlat Tindices, \
typename TTypes<T, 3>::Tensor Tout); \
extern template struct GatherFunctor<GPUDevice, T, Index>;


@@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -32,7 +34,8 @@ namespace functor {
// Helper method to copy using memcpy.
template <typename T, typename Index, typename SliceIndex,
SliceIndex static_slice_elems>
SliceIndex HandleCopies(typename TTypes<T, 3>::ConstTensor params,
SliceIndex HandleCopies(OpKernelContext* ctx,
typename TTypes<T, 3>::ConstTensor params,
typename TTypes<Index>::ConstFlat indices,
SliceIndex slice_elems,
typename TTypes<T, 3>::Tensor out) {
@@ -47,44 +50,64 @@ SliceIndex HandleCopies(typename TTypes<T, 3>::ConstTensor params,
}
// Compute slice_bytes here so that static knowledge is available
const size_t slice_bytes = slice_elems * sizeof(T);
for (SliceIndex b = 0; b < batch_size; b++) {
for (SliceIndex i = 0; i < indices_size; i++) {
const SliceIndex i_next = i + 1;
const SliceIndex b_next = b + 1;
if (i_next < indices_size) {
port::prefetch<port::PREFETCH_HINT_T0>(&params(b, indices(i_next), 0));
port::prefetch<port::PREFETCH_HINT_T0>(&out(b, i_next, 0));
} else if (b_next < batch_size) {
auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
mutex mu;
// Store the position of an out-of-range index (if any) for error reporting; this variable is shared across worker shards.
SliceIndex result = -1;
auto work = [&] (int64 start, int64 end) {
SliceIndex batch_idx = static_cast<SliceIndex>(start / indices_size);
SliceIndex indices_idx = static_cast<SliceIndex>(start % indices_size);
SliceIndex batch_idx_end = static_cast<SliceIndex>(end / indices_size);
SliceIndex indices_idx_end = static_cast<SliceIndex>(end % indices_size);
while ((batch_idx < batch_idx_end) ||
(batch_idx == batch_idx_end && indices_idx < indices_idx_end)) {
SliceIndex i_next = indices_idx + 1;
SliceIndex b_next = batch_idx + 1;
if ((batch_idx == batch_idx_end && i_next < indices_idx_end) ||
(i_next < indices_size)) {
port::prefetch<port::PREFETCH_HINT_T0>(&params(batch_idx, indices(i_next), 0));
port::prefetch<port::PREFETCH_HINT_T0>(&out(batch_idx, i_next, 0));
b_next = batch_idx;
} else if (b_next <= batch_idx_end) {
port::prefetch<port::PREFETCH_HINT_T0>(&params(b_next, indices(0), 0));
port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, 0, 0));
i_next = 0;
}
const Index index = internal::SubtleMustCopy(indices(indices_idx));
if (!FastBoundsCheck(index, limit)) {
mutex_lock l(mu);
result = indices_idx;
return;
}
// Grab the index and check its validity. An earlier version of the
// code checked it and then grabbed it from memory a second time, which
// was a security risk since it could have changed in between.
const Index index = internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Copy using memcpy if possible, otherwise an Eigen loop
// TODO(cwhipkey): avoid linking to framework to get Allocator (to improve
// ahead-of-time compilation binary size).
if (is_simple_type<T>::value) {
// Avoid auto-promotion to Index from SliceIndex by casting.
memcpy(out_base + (b * indices_size + i) * slice_elems,
params_base + (b * static_cast<SliceIndex>(limit) +
memcpy(out_base + (batch_idx * indices_size + indices_idx) * slice_elems,
params_base + (batch_idx * static_cast<SliceIndex>(limit) +
static_cast<SliceIndex>(index)) *
slice_elems,
slice_elems,
slice_bytes);
} else {
// For non-"simple" types (e.g. strings).
out.template chip<1>(i) = params.template chip<1>(index);
out.template chip<1>(indices_idx) = params.template chip<1>(index);
}
indices_idx = i_next;
batch_idx = b_next;
}
}
return -1;
};
Shard(worker_threads->num_threads, worker_threads->workers, batch_size*indices_size,
slice_elems * sizeof(T), work);
return result;
}
template <typename T, typename Index>
struct GatherFunctorCPU {
int64 operator()(typename TTypes<T, 3>::ConstTensor params,
int64 operator()(OpKernelContext* ctx,
typename TTypes<T, 3>::ConstTensor params,
typename TTypes<Index>::ConstFlat indices,
typename TTypes<T, 3>::Tensor out) {
const int64 N = indices.size();
@@ -94,16 +117,16 @@ struct GatherFunctorCPU {
bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
params.size() > std::numeric_limits<int32>::max() ||
N > std::numeric_limits<int32>::max());
#define CALL(elems) \
do { \
if (use_large) { \
bad_i = HandleCopies<T, Index, int64, elems>(params, indices, \
slice_size, out); \
} else { \
const int32 small_slice = static_cast<int32>(slice_size); \
bad_i = HandleCopies<T, Index, int32, elems>(params, indices, \
small_slice, out); \
} \
#define CALL(elems) \
do { \
if (use_large) { \
bad_i = HandleCopies<T, Index, int64, elems>(ctx, params, indices, \
slice_size, out); \
} else { \
const int32 small_slice = static_cast<int32>(slice_size); \
bad_i = HandleCopies<T, Index, int32, elems>(ctx, params, indices, \
small_slice, out); \
} \
} while (0)
if (slice_size == 10)
@@ -120,18 +143,18 @@ struct GatherFunctorCPU {
template <typename Device, typename T, typename Index>
struct GatherFunctor {
int64 operator()(const Device& d, typename TTypes<T, 3>::ConstTensor params,
int64 operator()(OpKernelContext* ctx, typename TTypes<T, 3>::ConstTensor params,
typename TTypes<Index>::ConstFlat indices,
typename TTypes<T, 3>::Tensor out);
};
template <typename T, typename Index>
struct GatherFunctor<CPUDevice, T, Index> {
int64 operator()(const CPUDevice& d,
int64 operator()(OpKernelContext* ctx,
typename TTypes<T, 3>::ConstTensor params,
typename TTypes<Index>::ConstFlat indices,
typename TTypes<T, 3>::Tensor out) {
return GatherFunctorCPU<T, Index>()(params, indices, out);
return GatherFunctorCPU<T, Index>()(ctx, params, indices, out);
}
};
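
A note on the prefetch calls in the hunks above: port::prefetch<port::PREFETCH_HINT_T0> comes from tensorflow/core/platform/prefetch.h and, on GCC/Clang, typically compiles down to the __builtin_prefetch intrinsic. A standalone sketch of the same prefetch-ahead copy pattern, with all names here illustrative rather than taken from the commit:

// Illustrative only: prefetch the next source/destination slice while the
// current slice is copied, analogous to the port::prefetch calls above.
#include <cstddef>
#include <cstring>

void GatherRows(const float* params, const int* indices, int num_indices,
                int slice_elems, float* out) {
  const size_t slice_bytes = slice_elems * sizeof(float);
  for (int i = 0; i < num_indices; ++i) {
#if defined(__GNUC__) || defined(__clang__)
    // Hint the next slices into cache while the current memcpy runs.
    if (i + 1 < num_indices) {
      __builtin_prefetch(params + indices[i + 1] * slice_elems, /*rw=*/0, /*locality=*/3);
      __builtin_prefetch(out + (i + 1) * slice_elems, /*rw=*/1, /*locality=*/3);
    }
#endif
    std::memcpy(out + i * slice_elems, params + indices[i] * slice_elems,
                slice_bytes);
  }
}
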


@@ -72,10 +72,11 @@ __global__ void GatherOpKernel(const T* params, const Index* indices, T* out,
namespace functor {
template <typename T, typename Index>
struct GatherFunctor<GPUDevice, T, Index> {
int64 operator()(const GPUDevice& d,
int64 operator()(OpKernelContext* ctx,
typename TTypes<T, 3>::ConstTensor params,
typename TTypes<Index>::ConstFlat indices,
typename TTypes<T, 3>::Tensor out) {
const GPUDevice& d = ctx->eigen_gpu_device();
const int64 out_size = out.size();
if (out_size == 0) {
// We need a check here since the CPU version does useful error checking


@@ -106,7 +106,7 @@ class GatherOp : public OpKernel {
auto out_flat = out->shaped<T, 3>({outer_size, N, inner_size});
functor::GatherFunctor<Device, T, Index> functor;
int64 bad_i = functor(c->eigen_device<Device>(), params_flat,
int64 bad_i = functor(c, params_flat,
indices_flat, out_flat);
OP_REQUIRES(


@@ -464,7 +464,7 @@ class ResourceGatherOp : public OpKernel {
auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});
functor::GatherFunctor<Device, T, Index> functor;
int64 bad_i = functor(c->eigen_device<Device>(), params_flat,
int64 bad_i = functor(c, params_flat,
indices_flat, out_flat);
OP_REQUIRES(