Merge code from PR #11940 with internal changes from cl/164796436, and update Python tests to also run on GPU.

PiperOrigin-RevId: 164929133
A. Unique TensorFlower 2017-08-10 17:55:10 -07:00 committed by TensorFlower Gardener
parent 9fba8c1851
commit e2a163a905
4 changed files with 251 additions and 27 deletions
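
For context, dynamic_stitch builds a single merged tensor by scattering slices of each data tensor into the output rows named by indices, with later inputs winning on duplicate indices. A minimal Python sketch of the op's behavior (illustrative values only; tf.dynamic_stitch is the public wrapper around the kernel changed below):

import tensorflow as tf

# Toy inputs: rows 0 and 2 come from the first input, rows 1 and 3 from the second.
indices = [tf.constant([0, 2]), tf.constant([1, 3])]
data = [tf.constant([10, 30]), tf.constant([20, 40])]
merged = tf.dynamic_stitch(indices, data)

with tf.Session() as sess:
    print(sess.run(merged))  # [10 20 30 40]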

tensorflow/core/kernels/BUILD

@@ -1573,6 +1573,10 @@ tf_kernel_library(
tf_kernel_library(
    name = "dynamic_stitch_op",
    gpu_srcs = [
        "cuda_device_array.h",
        "cuda_device_array_gpu.h",
    ],
    prefix = "dynamic_stitch_op",
    deps = DYNAMIC_DEPS,
)

tensorflow/core/kernels/dynamic_stitch_op.cc

@@ -21,8 +21,17 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/threadpool.h"

#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_device_array.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#ifdef GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA

template <class T>
class DynamicStitchOpImplBase : public OpKernel {
 public:
@@ -66,17 +75,24 @@ class DynamicStitchOpImplBase : public OpKernel {
  void CheckArgsAndAllocateResult(OpKernelContext* c,
                                  OpInputList* indices_inputs,
                                  OpInputList* data_inputs, int* first_dim_size,
                                  int* data_elements_size,
                                  Tensor** result_ptr) {
    // Find maximum index in the indices vectors
    OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));
    int32 max_index = -1;
    if (data_elements_size) {
      *data_elements_size = 0;
    }
    for (const Tensor& indices : *indices_inputs) {
      if (indices.NumElements() > 0) {
        Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
            indices.flat<int32>().maximum();
        max_index = std::max(m(), max_index);
      }
      if (data_elements_size) {
        *data_elements_size += indices.NumElements();
      }
    }

    *first_dim_size = max_index + 1;
@@ -90,18 +106,19 @@ class DynamicStitchOpImplBase : public OpKernel {
      const Tensor& data = (*data_inputs)[input_num];
      OP_REQUIRES(
          c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
          errors::InvalidArgument("data[", input_num,
                                  "].shape = ", data.shape().DebugString(),
                                  " does not start with indices[", input_num,
                                  "].shape = ", indices.shape().DebugString()));
      OP_REQUIRES(
          c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
          errors::InvalidArgument(
              "Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
              "].shape[", indices.dims(),
              ":], got data[0].shape = ", data0.shape().DebugString(),
              ", data[", input_num, "].shape = ", data.shape().DebugString(),
              ", indices[0].shape = ", indices0.shape().DebugString(),
              ", indices[", input_num,
              "].shape = ", indices.shape().DebugString()));
    }
@@ -116,10 +133,90 @@ class DynamicStitchOpImplBase : public OpKernel {
  }
};

#if GOOGLE_CUDA

template <typename T>
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                          const int32 slice_size, const int32 first_dim_size,
                          const CudaDeviceArrayStruct<int>& input_indices,
                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
                          T* output);

template <class T>
class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpGPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}

  void Compute(OpKernelContext* c) override {
    OpInputList indices_inputs;
    OpInputList data_inputs;
    int first_dim_size;
    int data_elements_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, &data_elements_size,
                                     &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error is
      // passed back in the context.
      return;
    }

    // TODO(jeff): Currently we leave uninitialized any portions of
    // merged that aren't covered by an index in indices. What should we do?
    if (first_dim_size > 0) {
      // Because duplicate indices are allowed, collisions have to be resolved
      // on the host before the data is handed to the GPU kernel.
      // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
      // last of duplicated indices, it could instead be done on the GPU
      // implicitly using atomics to make sure the last index is the final
      // write.
      const int slice_size = merged->flat_outer_dims<T>().dimension(1);
      CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
      CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
      OP_REQUIRES_OK(c, indices_flat.Init());
      OP_REQUIRES_OK(c, data_flat.Init());
      // Initialize indices_flat (-1 represents a missing index).
      for (int i = 0; i < first_dim_size; ++i) {
        indices_flat.Set(i, -1);
      }

      // Next write position in data_flat.
      int32 idx = 0;
      // Running sum of indices_inputs[i].NumElements(), used to compute the
      // values stored in indices_flat.
      int32 base_size = 0;
      for (int i = 0; i < indices_inputs.size(); ++i) {
        auto indices_vec = indices_inputs[i].flat<int32>();
        auto data_ptr_base = data_inputs[i].template flat<T>().data();
        for (int j = 0; j < indices_vec.size(); ++j) {
          // indices_flat is indexed by output row; its values are the positions
          // in data_flat where the corresponding input slices are stored.
          indices_flat.Set(indices_vec(j), base_size + j);
          data_flat.Set(
              idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
                                  j * slice_size));
          ++idx;
        }
        base_size += indices_vec.size();
      }
      OP_REQUIRES_OK(c, indices_flat.Finalize());
      OP_REQUIRES_OK(c, data_flat.Finalize());

      auto output = merged->template flat<T>().data();
      DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size, first_dim_size,
                              indices_flat.data(), data_flat.data(), output);
    }
  }
};

#endif  // GOOGLE_CUDA

template <class T, bool Parallel>
class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(
            c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}
@@ -129,7 +226,7 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
    int first_dim_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, nullptr, &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error is
      // passed back in the context.
@@ -207,13 +304,13 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
// functionality later.

template <typename T>
struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> {
  using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU;
};

template <typename T>
struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
  using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU;
};

#define REGISTER_DYNAMIC_STITCH(type)                     \
@@ -221,12 +318,12 @@ struct ParallelDynamicStitchOp : DynamicStitchOpImpl<T, true> {
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("indices"),     \
                          DynamicStitchOpCPU<type>)       \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")   \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("indices"),     \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH
@@ -236,19 +333,21 @@ TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")           \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("indices"),     \
                          DynamicStitchOpGPU<type>)       \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")   \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T")  \
                              .HostMemory("indices")      \
                              .HostMemory("data")         \
                              .HostMemory("merged"),      \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
#undef REGISTER_DYNAMIC_STITCH_GPU

#endif  // GOOGLE_CUDA
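
To make the host-side preparation in DynamicStitchOpGPU::Compute concrete, here is a sketch in plain Python of how indices_flat is filled. The toy index vectors and all names other than indices_flat are illustrative, not TensorFlow code; the -1 sentinel and last-duplicate-wins behavior mirror the loop above.

# Toy illustration of the indices_flat construction (not TensorFlow code).
indices_inputs = [[0, 2], [1, 2]]     # per-input index vectors (toy values)
first_dim_size = 3                    # max index + 1

indices_flat = [-1] * first_dim_size  # -1 marks output rows no index covers
base_size = 0
for indices_vec in indices_inputs:
    for j, out_row in enumerate(indices_vec):
        # A later duplicate overwrites an earlier one, so the last write wins:
        # output row 2 ends up pointing at flattened data entry 3.
        indices_flat[out_row] = base_size + j
    base_size += len(indices_vec)

print(indices_flat)  # [0, 2, 3]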

tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc

@@ -0,0 +1,81 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

using GPUDevice = Eigen::GpuDevice;

namespace {

template <typename T>
__global__ void DynamicStitchKernel(const int32 slice_size,
                                    const int32 output_size,
                                    CudaDeviceArrayStruct<int32> input_indices,
                                    CudaDeviceArrayStruct<const T*> input_ptrs,
                                    T* output) {
  int32* data_indices = GetCudaDeviceArrayOnDevice(&input_indices);
  const T** data_ptrs = GetCudaDeviceArrayOnDevice(&input_ptrs);
  CUDA_1D_KERNEL_LOOP(output_index, output_size) {
    const int32 slice_id = output_index / slice_size;
    const int32 slice_offset = output_index % slice_size;
    const int32 input_index = data_indices[slice_id];
    if (input_index != -1) {
      output[output_index] = ldg(data_ptrs[input_index] + slice_offset);
    }
  }
}

}  // namespace

template <typename T>
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                          const int32 slice_size, const int32 first_dim_size,
                          const CudaDeviceArrayStruct<int>& input_indices,
                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
                          T* output) {
  const int32 output_size = first_dim_size * slice_size;
  auto config = GetCudaLaunchConfig(output_size, gpu_device);

  DynamicStitchKernel<T>
      <<<config.block_count, config.thread_per_block, 0, gpu_device.stream()>>>(
          slice_size, output_size, input_indices, input_ptrs, output);
}

#define REGISTER_GPU(T)                                             \
  template void DynamicStitchGPUImpl(                               \
      const Eigen::GpuDevice& gpu_device, const int32 slice_size,   \
      const int32 first_dim_size,                                   \
      const CudaDeviceArrayStruct<int32>& input_indices,            \
      const CudaDeviceArrayStruct<const T*>& input_ptrs, T* output);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_int32(REGISTER_GPU);

#undef REGISTER_GPU

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
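
For intuition about the kernel's indexing, the same gather can be written as a scalar loop. The Python sketch below uses toy values; it is not the CUDA code, and the None entries stand in for the output rows the kernel leaves uninitialized when indices_flat holds -1.

# Scalar mirror of DynamicStitchKernel's index math (toy values only).
slice_size = 2
data_flat = [[10, 11], [20, 21], [30, 31]]  # stands in for the device pointer table
indices_flat = [2, -1, 0]                   # -1: no input covers this output row
output = [None] * (len(indices_flat) * slice_size)

for output_index in range(len(output)):
    slice_id, slice_offset = divmod(output_index, slice_size)
    input_index = indices_flat[slice_id]
    if input_index != -1:
        output[output_index] = data_flat[input_index][slice_offset]

print(output)  # [30, 31, None, None, 10, 11]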

tensorflow/python/kernel_tests/dynamic_stitch_op_test.py

@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
@@ -33,7 +34,7 @@ class DynamicStitchTestBase(object):
    self.stitch_op = stitch_op

  def testScalar(self):
    with self.test_session(use_gpu=True):
      indices = [constant_op.constant(0), constant_op.constant(1)]
      data = [constant_op.constant(40), constant_op.constant(60)]
      for step in -1, 1:
@@ -46,7 +47,7 @@ class DynamicStitchTestBase(object):
      self.assertEqual([None], stitched_t.get_shape().as_list())

  def testSimpleOneDimensional(self):
    with self.test_session(use_gpu=True):
      indices = [
          constant_op.constant([0, 4, 7]), constant_op.constant([1, 6, 2, 3, 5])
      ]
@@ -63,7 +64,7 @@ class DynamicStitchTestBase(object):
      self.assertEqual([None], stitched_t.get_shape().as_list())

  def testOneListOneDimensional(self):
    with self.test_session(use_gpu=True):
      indices = [constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])]
      data = [constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])]
      stitched_t = self.stitch_op(indices, data)
@@ -75,7 +76,7 @@ class DynamicStitchTestBase(object):
      self.assertEqual([None], stitched_t.get_shape().as_list())

  def testSimpleTwoDimensional(self):
    with self.test_session(use_gpu=True):
      indices = [
          constant_op.constant([0, 4, 7]), constant_op.constant([1, 6]),
          constant_op.constant([2, 3, 5])
      ]
@@ -95,7 +96,7 @@ class DynamicStitchTestBase(object):
      self.assertEqual([None, 2], stitched_t.get_shape().as_list())

  def testHigherRank(self):
    with self.test_session(use_gpu=True) as sess:
      indices = [
          constant_op.constant(6), constant_op.constant([4, 1]),
          constant_op.constant([[5, 2], [0, 3]])
      ]
@@ -176,6 +177,45 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
    test.TestCase.__init__(self, *test_case_args)
    DynamicStitchTestBase.__init__(self, data_flow_ops.parallel_dynamic_stitch)

  def testScalar(self):
    with self.test_session(use_gpu=True):
      indices = [constant_op.constant(0), constant_op.constant(1)]
      data = [constant_op.constant(40.0), constant_op.constant(60.0)]
      for step in -1, 1:
        stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
        stitched_val = stitched_t.eval()
        self.assertAllEqual([40.0, 60.0][::step], stitched_val)
        # Dimension 0 is determined by the max index in indices, so we
        # can only infer that the output is a vector of some unknown
        # length.
        self.assertEqual([None], stitched_t.get_shape().as_list())

  def testHigherRank(self):
    with self.test_session(use_gpu=True) as sess:
      indices = [
          constant_op.constant(6),
          constant_op.constant([4, 1]),
          constant_op.constant([[5, 2], [0, 3]])
      ]
      data = [
          constant_op.constant([61, 62], dtype=dtypes.float32),
          constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32),
          constant_op.constant(
              [[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32)
      ]
      stitched_t = data_flow_ops.dynamic_stitch(indices, data)
      stitched_val = stitched_t.eval()
      correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
      self.assertAllEqual(correct, stitched_val)
      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
      # Test gradients
      stitched_grad = 7 * stitched_val
      grads = gradients_impl.gradients(stitched_t, indices + data,
                                       stitched_grad)
      self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
      for datum, grad in zip(data, sess.run(grads[3:])):
        self.assertAllEqual(7.0 * datum.eval(), grad)


if __name__ == "__main__":
  test.main()