From e2a163a90561bef0accdd7a0f200f692d85e14c9 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 10 Aug 2017 17:55:10 -0700
Subject: [PATCH] Merge code from PR #11940 with internal changes from
 cl/164796436, and update Python tests to also run on GPU.

PiperOrigin-RevId: 164929133
---
 tensorflow/core/kernels/BUILD                      |   4 +
 tensorflow/core/kernels/dynamic_stitch_op.cc       | 143 +++++++++++++++---
 .../core/kernels/dynamic_stitch_op_gpu.cu.cc       |  81 ++++++++++
 .../kernel_tests/dynamic_stitch_op_test.py         |  50 +++++-
 4 files changed, 251 insertions(+), 27 deletions(-)
 create mode 100644 tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index dd23bcab681..a5e3a5feea2 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1573,6 +1573,10 @@ tf_kernel_library(
 tf_kernel_library(
     name = "dynamic_stitch_op",
+    gpu_srcs = [
+        "cuda_device_array.h",
+        "cuda_device_array_gpu.h",
+    ],
     prefix = "dynamic_stitch_op",
     deps = DYNAMIC_DEPS,
 )
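For orientation before the kernel changes below: dynamic_stitch scatters slices of each data[i] into the output row named by the matching entry of indices[i], and when the same output row is named more than once, a later write wins. A rough NumPy model of those semantics (illustrative only, not part of this patch; the real op leaves unreferenced rows uninitialized rather than zeroed):

    import numpy as np

    def dynamic_stitch_reference(indices_list, data_list):
        # merged[indices[m][i]] = data[m][i], with later writes overwriting.
        flat_indices = np.concatenate([np.ravel(i) for i in indices_list])
        # data[m].shape == indices[m].shape + extra_shape; flatten each data[m]
        # to (num_slices,) + extra_shape so rows line up with flat_indices.
        extra_shape = data_list[0].shape[np.ndim(indices_list[0]):]
        slices = np.concatenate(
            [np.reshape(d, (-1,) + extra_shape) for d in data_list])
        merged = np.zeros((flat_indices.max() + 1,) + extra_shape, slices.dtype)
        for idx, row in zip(flat_indices, slices):
            merged[idx] = row
        return merged

    # dynamic_stitch_reference([np.array([0, 2]), np.array([1])],
    #                          [np.array([[1., 2.], [5., 6.]]),
    #                           np.array([[3., 4.]])])
    # -> [[1., 2.], [3., 4.], [5., 6.]]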
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index f011f34fa8f..99bcd90a4e0 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -21,8 +21,17 @@ limitations under the License.
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 
+#ifdef GOOGLE_CUDA
+#include "tensorflow/core/kernels/cuda_device_array.h"
+#endif  // GOOGLE_CUDA
+
 namespace tensorflow {
 
+typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifdef GOOGLE_CUDA
+typedef Eigen::GpuDevice GPUDevice;
+#endif  // GOOGLE_CUDA
+
 template <class T>
 class DynamicStitchOpImplBase : public OpKernel {
  public:
@@ -66,17 +75,24 @@ class DynamicStitchOpImplBase : public OpKernel {
   void CheckArgsAndAllocateResult(OpKernelContext* c,
                                   OpInputList* indices_inputs,
                                   OpInputList* data_inputs, int* first_dim_size,
+                                  int* data_elements_size,
                                   Tensor** result_ptr) {
     // Find maximum index in the indices vectors
     OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));
     int32 max_index = -1;
+    if (data_elements_size) {
+      *data_elements_size = 0;
+    }
     for (const Tensor& indices : *indices_inputs) {
       if (indices.NumElements() > 0) {
         Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
             indices.flat<int32>().maximum();
         max_index = std::max(m(), max_index);
       }
+      if (data_elements_size) {
+        *data_elements_size += indices.NumElements();
+      }
     }
 
     *first_dim_size = max_index + 1;
@@ -90,18 +106,19 @@ class DynamicStitchOpImplBase : public OpKernel {
       const Tensor& data = (*data_inputs)[input_num];
       OP_REQUIRES(
           c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
-          errors::InvalidArgument("data[", input_num, "].shape = ",
-                                  data.shape().DebugString(),
+          errors::InvalidArgument("data[", input_num,
+                                  "].shape = ", data.shape().DebugString(),
                                   " does not start with indices[", input_num,
                                   "].shape = ", indices.shape().DebugString()));
       OP_REQUIRES(
           c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
           errors::InvalidArgument(
               "Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
-              "].shape[", indices.dims(), ":], got data[0].shape = ",
-              data0.shape().DebugString(), ", data[", input_num, "].shape = ",
-              data.shape().DebugString(), ", indices[0].shape = ",
-              indices0.shape().DebugString(), ", indices[", input_num,
+              "].shape[", indices.dims(),
+              ":], got data[0].shape = ", data0.shape().DebugString(),
+              ", data[", input_num, "].shape = ", data.shape().DebugString(),
+              ", indices[0].shape = ", indices0.shape().DebugString(),
+              ", indices[", input_num,
               "].shape = ", indices.shape().DebugString()));
     }
 
@@ -116,10 +133,90 @@ class DynamicStitchOpImplBase : public OpKernel {
   }
 };
 
-template <class T, bool Parallel>
-class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
+#if GOOGLE_CUDA
+
+template <typename T>
+void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
+                          const int32 slice_size, const int32 first_dim_size,
+                          const CudaDeviceArrayStruct<int32>& input_indices,
+                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
+                          T* output);
+
+template <class T>
+class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
  public:
-  explicit DynamicStitchOpImpl(OpKernelConstruction* c)
+  explicit DynamicStitchOpGPU(OpKernelConstruction* c)
+      : DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}
+
+  void Compute(OpKernelContext* c) override {
+    OpInputList indices_inputs;
+    OpInputList data_inputs;
+    int first_dim_size;
+    int data_elements_size;
+    Tensor* merged = nullptr;
+    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
+                                     &first_dim_size, &data_elements_size,
+                                     &merged);
+    if (!c->status().ok()) {
+      // Avoid segmentation faults if merged cannot be allocated and an error
+      // is passed back in the context.
+      return;
+    }
+
+    // TODO(jeff): Currently we leave uninitialized any portions of
+    // merged that aren't covered by an index in indices. What should we do?
+    if (first_dim_size > 0) {
+      // Because of the collision requirements, we have to resolve
+      // collisions on the CPU before sending the data to the GPU kernel.
+      // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
+      // last of duplicated indices, it could instead be done on the GPU
+      // implicitly using atomics to make sure the last index is the final
+      // write.
+      const int slice_size = merged->flat_outer_dims<T>().dimension(1);
+      CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
+      CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
+      OP_REQUIRES_OK(c, indices_flat.Init());
+      OP_REQUIRES_OK(c, data_flat.Init());
+      // Initialize indices_flat (-1 represents a missing index).
+      for (int i = 0; i < first_dim_size; ++i) {
+        indices_flat.Set(i, -1);
+      }
+
+      // data_flat index
+      int32 idx = 0;
+      // Running sum of indices_inputs[i].NumElements(), used to compute the
+      // indices_flat values.
+      int32 base_size = 0;
+      for (int i = 0; i < indices_inputs.size(); ++i) {
+        auto indices_vec = indices_inputs[i].flat<int32>();
+        auto data_ptr_base = data_inputs[i].template flat<T>().data();
+        for (int j = 0; j < indices_vec.size(); ++j) {
+          // indices_flat's indices represent the indices of the output; its
+          // values represent the positions within the input data where the
+          // corresponding slices are located.
+          indices_flat.Set(indices_vec(j), base_size + j);
+          data_flat.Set(
+              idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
+                                  j * slice_size));
+          ++idx;
+        }
+        base_size += indices_vec.size();
+      }
+      OP_REQUIRES_OK(c, indices_flat.Finalize());
+      OP_REQUIRES_OK(c, data_flat.Finalize());
+
+      auto output = merged->template flat<T>().data();
+      DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size,
+                              first_dim_size, indices_flat.data(),
+                              data_flat.data(), output);
+    }
+  }
+};
+
+#endif  // GOOGLE_CUDA
+
+template <class T, bool Parallel>
+class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
+ public:
+  explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
       : DynamicStitchOpImplBase<T>(
             c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}
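The host-side scan above is the collision-resolution step: indices_flat maps each output row to the position of the last input slice that targets it (so duplicate indices resolve to the final write, matching the CPU kernel's behavior), and -1 marks rows no index covers, which the CUDA kernel then skips. A small pure-Python model of that scan (illustrative; the names mirror the C++ above, but this helper is not part of the patch):

    def build_indices_flat(indices_inputs, first_dim_size):
        indices_flat = [-1] * first_dim_size  # -1: output row never written
        base_size = 0  # running count of slices from earlier inputs
        for indices_vec in indices_inputs:
            for j, out_row in enumerate(indices_vec):
                # Later writes overwrite earlier ones, so on duplicate output
                # rows the last occurrence wins.
                indices_flat[out_row] = base_size + j
            base_size += len(indices_vec)
        return indices_flat

    # build_indices_flat([[0, 2], [1, 0]], 3) -> [3, 2, 1]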
"ParallelDynamicStitchOp" : "DynamicStitchOp")) {} @@ -129,7 +226,7 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase { int first_dim_size; Tensor* merged = nullptr; this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs, - &first_dim_size, &merged); + &first_dim_size, nullptr, &merged); if (!c->status().ok()) { // Avoid segmentation faults if merged cannot be allocated and an error is // passed back in the context. @@ -207,13 +304,13 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase { // functionality later. template -struct DynamicStitchOp : DynamicStitchOpImpl { - using DynamicStitchOpImpl::DynamicStitchOpImpl; +struct DynamicStitchOpCPU : DynamicStitchOpImplCPU { + using DynamicStitchOpImplCPU::DynamicStitchOpImplCPU; }; template -struct ParallelDynamicStitchOp : DynamicStitchOpImpl { - using DynamicStitchOpImpl::DynamicStitchOpImpl; +struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU { + using DynamicStitchOpImplCPU::DynamicStitchOpImplCPU; }; #define REGISTER_DYNAMIC_STITCH(type) \ @@ -221,12 +318,12 @@ struct ParallelDynamicStitchOp : DynamicStitchOpImpl { .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .HostMemory("indices"), \ - DynamicStitchOp) \ + DynamicStitchOpCPU) \ REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .HostMemory("indices"), \ - ParallelDynamicStitchOp) + ParallelDynamicStitchOpCPU) TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH); #undef REGISTER_DYNAMIC_STITCH @@ -236,19 +333,21 @@ TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH); REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \ .Device(DEVICE_GPU) \ .TypeConstraint("T") \ - .HostMemory("indices") \ - .HostMemory("data") \ - .HostMemory("merged"), \ - DynamicStitchOp) \ + .HostMemory("indices"), \ + DynamicStitchOpGPU) \ REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch") \ .Device(DEVICE_GPU) \ .TypeConstraint("T") \ .HostMemory("indices") \ .HostMemory("data") \ .HostMemory("merged"), \ - ParallelDynamicStitchOp) + ParallelDynamicStitchOpCPU) -TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_GPU); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU); +TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU); +TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU); +TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU); +TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU); #undef REGISTER_DYNAMIC_STITCH_GPU #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc new file mode 100644 index 00000000000..102cdc40d42 --- /dev/null +++ b/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc
new file mode 100644
index 00000000000..102cdc40d42
--- /dev/null
+++ b/tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc
@@ -0,0 +1,81 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+using GPUDevice = Eigen::GpuDevice;
+
+namespace {
+
+template <typename T>
+__global__ void DynamicStitchKernel(const int32 slice_size,
+                                    const int32 output_size,
+                                    CudaDeviceArrayStruct<int32> input_indices,
+                                    CudaDeviceArrayStruct<const T*> input_ptrs,
+                                    T* output) {
+  int32* data_indices = GetCudaDeviceArrayOnDevice(&input_indices);
+  const T** data_ptrs = GetCudaDeviceArrayOnDevice(&input_ptrs);
+  CUDA_1D_KERNEL_LOOP(output_index, output_size) {
+    const int32 slice_id = output_index / slice_size;
+    const int32 slice_offset = output_index % slice_size;
+    const int32 input_index = data_indices[slice_id];
+    if (input_index != -1) {
+      output[output_index] = ldg(data_ptrs[input_index] + slice_offset);
+    }
+  }
+}
+
+}  // namespace
+
+template <typename T>
+void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
+                          const int32 slice_size, const int32 first_dim_size,
+                          const CudaDeviceArrayStruct<int32>& input_indices,
+                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
+                          T* output) {
+  const int32 output_size = first_dim_size * slice_size;
+  auto config = GetCudaLaunchConfig(output_size, gpu_device);
+
+  DynamicStitchKernel<T>
+      <<<config.block_count, config.thread_per_block, 0, gpu_device.stream()>>>(
+          slice_size, output_size, input_indices, input_ptrs, output);
+}
+
+#define REGISTER_GPU(T)                                           \
+  template void DynamicStitchGPUImpl(                             \
+      const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
+      const int32 first_dim_size,                                 \
+      const CudaDeviceArrayStruct<int32>& input_indices,          \
+      const CudaDeviceArrayStruct<const T*>& input_ptrs, T* output);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_complex64(REGISTER_GPU);
+TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
+TF_CALL_int32(REGISTER_GPU);
+
+#undef REGISTER_GPU
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
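The kernel launches one logical thread per element of merged; each thread decodes its flat offset into an output row and a column, looks up the winning input slice for that row, and copies a single element, leaving rows marked -1 untouched. The same index arithmetic in Python (illustrative only; data_ptrs here is a list of flat slices rather than raw device pointers):

    def dynamic_stitch_kernel_model(slice_size, output_size, data_indices,
                                    data_ptrs, output):
        for output_index in range(output_size):  # CUDA_1D_KERNEL_LOOP
            slice_id = output_index // slice_size     # output row
            slice_offset = output_index % slice_size  # column within the row
            input_index = data_indices[slice_id]
            if input_index != -1:  # -1: row not covered by any index
                output[output_index] = data_ptrs[input_index][slice_offset]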
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index f894da51571..b4a5e1f4221 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gradients_impl
 import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
@@ -33,7 +34,7 @@ class DynamicStitchTestBase(object):
     self.stitch_op = stitch_op
 
   def testScalar(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       indices = [constant_op.constant(0), constant_op.constant(1)]
       data = [constant_op.constant(40), constant_op.constant(60)]
       for step in -1, 1:
@@ -46,7 +47,7 @@ class DynamicStitchTestBase(object):
       self.assertEqual([None], stitched_t.get_shape().as_list())
 
   def testSimpleOneDimensional(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       indices = [
           constant_op.constant([0, 4, 7]), constant_op.constant([1, 6, 2, 3, 5])
       ]
@@ -63,7 +64,7 @@ class DynamicStitchTestBase(object):
       self.assertEqual([None], stitched_t.get_shape().as_list())
 
   def testOneListOneDimensional(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       indices = [constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])]
       data = [constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])]
       stitched_t = self.stitch_op(indices, data)
@@ -75,7 +76,7 @@ class DynamicStitchTestBase(object):
       self.assertEqual([None], stitched_t.get_shape().as_list())
 
   def testSimpleTwoDimensional(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       indices = [
           constant_op.constant([0, 4, 7]), constant_op.constant([1, 6]),
           constant_op.constant([2, 3, 5])
       ]
@@ -95,7 +96,7 @@ class DynamicStitchTestBase(object):
       self.assertEqual([None, 2], stitched_t.get_shape().as_list())
 
   def testHigherRank(self):
-    with self.test_session() as sess:
+    with self.test_session(use_gpu=True) as sess:
       indices = [
           constant_op.constant(6), constant_op.constant([4, 1]),
           constant_op.constant([[5, 2], [0, 3]])
       ]
@@ -176,6 +177,45 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
     test.TestCase.__init__(self, *test_case_args)
     DynamicStitchTestBase.__init__(self, data_flow_ops.parallel_dynamic_stitch)
 
+  def testScalar(self):
+    with self.test_session(use_gpu=True):
+      indices = [constant_op.constant(0), constant_op.constant(1)]
+      data = [constant_op.constant(40.0), constant_op.constant(60.0)]
+      for step in -1, 1:
+        stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
+        stitched_val = stitched_t.eval()
+        self.assertAllEqual([40.0, 60.0][::step], stitched_val)
+        # Dimension 0 is determined by the max index in indices, so we
+        # can only infer that the output is a vector of some unknown
+        # length.
+        self.assertEqual([None], stitched_t.get_shape().as_list())
+
+  def testHigherRank(self):
+    with self.test_session(use_gpu=True) as sess:
+      indices = [
+          constant_op.constant(6), constant_op.constant([4, 1]),
+          constant_op.constant([[5, 2], [0, 3]])
+      ]
+      data = [
+          constant_op.constant([61, 62], dtype=dtypes.float32),
+          constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32),
+          constant_op.constant(
+              [[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32)
+      ]
+      stitched_t = data_flow_ops.dynamic_stitch(indices, data)
+      stitched_val = stitched_t.eval()
+      correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
+      self.assertAllEqual(correct, stitched_val)
+      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+      # Test gradients
+      stitched_grad = 7 * stitched_val
+      grads = gradients_impl.gradients(stitched_t, indices + data,
+                                       stitched_grad)
+      self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
+      for datum, grad in zip(data, sess.run(grads[3:])):
+        self.assertAllEqual(7.0 * datum.eval(), grad)
+
 
 if __name__ == "__main__":
   test.main()
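The gradient checks above rest on the identity that dynamic_stitch's gradient with respect to data[i] is the upstream gradient gathered at indices[i], while the indices themselves get no gradient. A standalone sketch of that identity outside the test harness, again assuming the 1.x session API:

    import tensorflow as tf

    indices = [tf.constant([0, 2]), tf.constant([1])]
    data = [tf.constant([[1., 2.], [5., 6.]]), tf.constant([[3., 4.]])]
    merged = tf.dynamic_stitch(indices, data)
    grad_ys = 7.0 * tf.ones_like(merged)
    grads = tf.gradients(merged, data, grad_ys=grad_ys)

    with tf.Session() as sess:
        for idx, grad in zip(indices, grads):
            # grad w.r.t. data[i] equals the upstream gradient gathered at
            # indices[i] (exact here because no indices collide).
            print(sess.run(grad), sess.run(tf.gather(grad_ys, idx)))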