Merge code from PR #11940 with internal changes from cl/164796436, and update Python tests to also run on GPU.
PiperOrigin-RevId: 164929133

Commit e2a163a905 (parent 9fba8c1851)
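For orientation: `dynamic_stitch` interleaves a list of `data` tensors into one output, using a parallel list of `indices` tensors to say which output row each data slice fills; this commit gives the op a real CUDA kernel instead of forcing GPU placements to round-trip through host memory. A minimal NumPy sketch of the op's semantics (illustrative only, not TensorFlow code; the zero-fill for uncovered rows is a simplification, since the kernel added below leaves such rows uninitialized):

import numpy as np

def dynamic_stitch(indices_list, data_list):
  # NumPy model of the op: later (input, element) pairs win on duplicates.
  flat_indices = np.concatenate([np.ravel(i) for i in indices_list])
  # data[i].shape must start with indices[i].shape; the rest is the slice shape.
  slice_shape = data_list[0].shape[np.ndim(indices_list[0]):]
  flat_data = np.concatenate(
      [np.reshape(d, (-1,) + slice_shape) for d in data_list])
  # Zero-fill is a simplification; the GPU kernel below leaves rows whose
  # index never appears uninitialized.
  out = np.zeros((flat_indices.max() + 1,) + slice_shape, flat_data.dtype)
  for idx, slc in zip(flat_indices, flat_data):
    out[idx] = slc
  return out

print(dynamic_stitch([np.array([0, 2]), np.array([1])],
                     [np.array([10, 30]), np.array([20])]))  # [10 20 30]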
tensorflow/core/kernels/BUILD:

@@ -1573,6 +1573,10 @@ tf_kernel_library(
 tf_kernel_library(
     name = "dynamic_stitch_op",
+    gpu_srcs = [
+        "cuda_device_array.h",
+        "cuda_device_array_gpu.h",
+    ],
     prefix = "dynamic_stitch_op",
     deps = DYNAMIC_DEPS,
 )
tensorflow/core/kernels/dynamic_stitch_op.cc:

@@ -21,8 +21,17 @@ limitations under the License.
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/lib/core/threadpool.h"
+
+#ifdef GOOGLE_CUDA
+#include "tensorflow/core/kernels/cuda_device_array.h"
+#endif  // GOOGLE_CUDA
 
 namespace tensorflow {
 
+typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifdef GOOGLE_CUDA
+typedef Eigen::GpuDevice GPUDevice;
+#endif  // GOOGLE_CUDA
+
 template <class T>
 class DynamicStitchOpImplBase : public OpKernel {
  public:
@@ -66,17 +75,24 @@ class DynamicStitchOpImplBase : public OpKernel {
   void CheckArgsAndAllocateResult(OpKernelContext* c,
                                   OpInputList* indices_inputs,
                                   OpInputList* data_inputs, int* first_dim_size,
+                                  int* data_elements_size,
                                   Tensor** result_ptr) {
     // Find maximum index in the indices vectors
     OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));
 
     int32 max_index = -1;
+    if (data_elements_size) {
+      *data_elements_size = 0;
+    }
     for (const Tensor& indices : *indices_inputs) {
       if (indices.NumElements() > 0) {
         Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
             indices.flat<int32>().maximum();
         max_index = std::max(m(), max_index);
       }
+      if (data_elements_size) {
+        *data_elements_size += indices.NumElements();
+      }
     }
 
     *first_dim_size = max_index + 1;
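The new `data_elements_size` out-parameter totals `indices_inputs[i].NumElements()` across all inputs; the GPU path uses it to size the flattened table of per-slice data pointers, while `first_dim_size` stays `max_index + 1` as before. A small illustrative sketch of both quantities, with NumPy arrays standing in for the int32 index tensors:

import numpy as np

indices_inputs = [np.array([0, 4, 7]), np.array([1, 6, 2, 3, 5])]

# Output rows: largest index seen anywhere, plus one.
first_dim_size = max(int(i.max()) for i in indices_inputs if i.size) + 1
# Total number of data slices contributed by all inputs.
data_elements_size = sum(i.size for i in indices_inputs)

print(first_dim_size, data_elements_size)  # 8 8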
@@ -90,18 +106,19 @@ class DynamicStitchOpImplBase : public OpKernel {
       const Tensor& data = (*data_inputs)[input_num];
       OP_REQUIRES(
           c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
-          errors::InvalidArgument("data[", input_num, "].shape = ",
-                                  data.shape().DebugString(),
+          errors::InvalidArgument("data[", input_num,
+                                  "].shape = ", data.shape().DebugString(),
                                   " does not start with indices[", input_num,
                                   "].shape = ", indices.shape().DebugString()));
       OP_REQUIRES(
           c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
           errors::InvalidArgument(
               "Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
-              "].shape[", indices.dims(), ":], got data[0].shape = ",
-              data0.shape().DebugString(), ", data[", input_num, "].shape = ",
-              data.shape().DebugString(), ", indices[0].shape = ",
-              indices0.shape().DebugString(), ", indices[", input_num,
+              "].shape[", indices.dims(),
+              ":], got data[0].shape = ", data0.shape().DebugString(),
+              ", data[", input_num, "].shape = ", data.shape().DebugString(),
+              ", indices[0].shape = ", indices0.shape().DebugString(),
+              ", indices[", input_num,
               "].shape = ", indices.shape().DebugString()));
     }
@@ -116,10 +133,90 @@ class DynamicStitchOpImplBase : public OpKernel {
   }
 };
 
-template <class T, bool Parallel>
-class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
+#if GOOGLE_CUDA
+
+template <typename T>
+void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
+                          const int32 slice_size, const int32 first_dim_size,
+                          const CudaDeviceArrayStruct<int>& input_indices,
+                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
+                          T* output);
+
+template <class T>
+class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
  public:
-  explicit DynamicStitchOpImpl(OpKernelConstruction* c)
+  explicit DynamicStitchOpGPU(OpKernelConstruction* c)
+      : DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}
+
+  void Compute(OpKernelContext* c) override {
+    OpInputList indices_inputs;
+    OpInputList data_inputs;
+    int first_dim_size;
+    int data_elements_size;
+    Tensor* merged = nullptr;
+    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
+                                     &first_dim_size, &data_elements_size,
+                                     &merged);
+    if (!c->status().ok()) {
+      // Avoid segmentation faults if merged cannot be allocated and an error is
+      // passed back in the context.
+      return;
+    }
+
+    // TODO(jeff): Currently we leave uninitialized any portions of
+    // merged that aren't covered by an index in indices. What should we do?
+    if (first_dim_size > 0) {
+      // Because duplicate indices must resolve to the last occurrence, we
+      // have to deal with the collisions before sending data to the GPU
+      // kernel.
+      // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
+      // last of duplicated indices, it could instead be done on the GPU
+      // implicitly using atomics to make sure the last index is the final
+      // write.
+      const int slice_size = merged->flat_outer_dims<T>().dimension(1);
+      CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
+      CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
+      OP_REQUIRES_OK(c, indices_flat.Init());
+      OP_REQUIRES_OK(c, data_flat.Init());
+      // Initialize indices_flat (-1 represents a missing index).
+      for (int i = 0; i < first_dim_size; ++i) {
+        indices_flat.Set(i, -1);
+      }
+
+      // data_flat index.
+      int32 idx = 0;
+      // Running sum of indices_inputs[i].NumElements(), used to compute the
+      // values stored in indices_flat.
+      int32 base_size = 0;
+      for (int i = 0; i < indices_inputs.size(); ++i) {
+        auto indices_vec = indices_inputs[i].flat<int32>();
+        auto data_ptr_base = data_inputs[i].template flat<T>().data();
+        for (int j = 0; j < indices_vec.size(); ++j) {
+          // indices_flat's indices represent the indices of the output;
+          // indices_flat's values represent the (flattened) indices of the
+          // input slices where the data is located.
+          indices_flat.Set(indices_vec(j), base_size + j);
+          data_flat.Set(
+              idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
+                                  j * slice_size));
+          ++idx;
+        }
+        base_size += indices_vec.size();
+      }
+      OP_REQUIRES_OK(c, indices_flat.Finalize());
+      OP_REQUIRES_OK(c, data_flat.Finalize());
+
+      auto output = merged->template flat<T>().data();
+      DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size, first_dim_size,
+                              indices_flat.data(), data_flat.data(), output);
+    }
+  }
+};
+
+#endif  // GOOGLE_CUDA
+
+template <class T, bool Parallel>
+class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
+ public:
+  explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
       : DynamicStitchOpImplBase<T>(
             c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}
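Because duplicate indices must resolve to the last occurrence, the host side above does the collision handling serially before launch: `indices_flat` maps each output row to a flattened input-slice number (`-1` for uncovered rows), and `data_flat` holds one device pointer per input slice. A Python sketch of that scan, with plain lists standing in for the `CudaDeviceArrayOnHost` buffers and `(input, slice)` pairs standing in for the `T*` pointers:

def build_flat_tables(indices_inputs, first_dim_size):
  # -1 marks output rows not covered by any index (left uninitialized on GPU).
  indices_flat = [-1] * first_dim_size
  data_flat = []  # stands in for the per-slice device pointers
  base_size = 0
  for i, indices_vec in enumerate(indices_inputs):
    for j, out_row in enumerate(indices_vec):
      # Later inputs overwrite earlier ones, so the last duplicate wins,
      # exactly like the serial CPU scan in Compute().
      indices_flat[out_row] = base_size + j
      data_flat.append((i, j))  # (input tensor, slice offset) in lieu of T*
    base_size += len(indices_vec)
  return indices_flat, data_flat

indices_flat, data_flat = build_flat_tables([[0, 2], [2, 1]], first_dim_size=4)
print(indices_flat)  # [0, 3, 2, -1] -- row 2 is taken from the second input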
@@ -129,7 +226,7 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
     int first_dim_size;
     Tensor* merged = nullptr;
     this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
-                                     &first_dim_size, &merged);
+                                     &first_dim_size, nullptr, &merged);
     if (!c->status().ok()) {
       // Avoid segmentation faults if merged cannot be allocated and an error is
       // passed back in the context.
@@ -207,13 +304,13 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
 // functionality later.
 
 template <typename T>
-struct DynamicStitchOp : DynamicStitchOpImpl<T, false> {
-  using DynamicStitchOpImpl<T, false>::DynamicStitchOpImpl;
+struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> {
+  using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU;
 };
 
 template <typename T>
-struct ParallelDynamicStitchOp : DynamicStitchOpImpl<T, true> {
-  using DynamicStitchOpImpl<T, true>::DynamicStitchOpImpl;
+struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
+  using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU;
 };
 
 #define REGISTER_DYNAMIC_STITCH(type) \
@@ -221,12 +318,12 @@ struct ParallelDynamicStitchOp : DynamicStitchOpImpl<T, true> {
                               .Device(DEVICE_CPU)        \
                               .TypeConstraint<type>("T") \
                               .HostMemory("indices"),    \
-                          DynamicStitchOp<type>)         \
+                          DynamicStitchOpCPU<type>)      \
   REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                               .Device(DEVICE_CPU)        \
                               .TypeConstraint<type>("T") \
                               .HostMemory("indices"),    \
-                          ParallelDynamicStitchOp<type>)
+                          ParallelDynamicStitchOpCPU<type>)
 
 TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
 #undef REGISTER_DYNAMIC_STITCH
@@ -236,19 +333,21 @@ TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
   REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                               .Device(DEVICE_GPU)        \
                               .TypeConstraint<type>("T") \
-                              .HostMemory("indices")     \
-                              .HostMemory("data")        \
-                              .HostMemory("merged"),     \
-                          DynamicStitchOp<type>)         \
+                              .HostMemory("indices"),    \
+                          DynamicStitchOpGPU<type>)      \
   REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                               .Device(DEVICE_GPU)        \
                               .TypeConstraint<type>("T") \
                               .HostMemory("indices")     \
                               .HostMemory("data")        \
                               .HostMemory("merged"),     \
-                          ParallelDynamicStitchOp<type>)
+                          ParallelDynamicStitchOpCPU<type>)
 
-TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
+TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
 #undef REGISTER_DYNAMIC_STITCH_GPU
 
 #endif  // GOOGLE_CUDA
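Net effect of the registration change: on `DEVICE_GPU`, `DynamicStitch` now runs the new CUDA kernel for the numeric types above, with only `indices` pinned to host memory, while `ParallelDynamicStitch` keeps the CPU implementation with all of `indices`, `data`, and `merged` in host memory. A rough way to observe the placement with this era's TF 1.x API (a sketch, assuming a CUDA build and one visible GPU):

import tensorflow as tf

with tf.device("/gpu:0"):
  indices = [tf.constant([0, 2]), tf.constant([1])]
  data = [tf.constant([10.0, 30.0]), tf.constant([20.0])]
  merged = tf.dynamic_stitch(indices, data)

# With log_device_placement=True the DynamicStitch node should now be
# assigned to the GPU for float32 inputs.
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
  print(sess.run(merged))  # [10. 20. 30.]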
tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc (new file, 81 lines):

@@ -0,0 +1,81 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+using GPUDevice = Eigen::GpuDevice;
+
+namespace {
+
+template <typename T>
+__global__ void DynamicStitchKernel(const int32 slice_size,
+                                    const int32 output_size,
+                                    CudaDeviceArrayStruct<int32> input_indices,
+                                    CudaDeviceArrayStruct<const T*> input_ptrs,
+                                    T* output) {
+  int32* data_indices = GetCudaDeviceArrayOnDevice(&input_indices);
+  const T** data_ptrs = GetCudaDeviceArrayOnDevice(&input_ptrs);
+  CUDA_1D_KERNEL_LOOP(output_index, output_size) {
+    const int32 slice_id = output_index / slice_size;
+    const int32 slice_offset = output_index % slice_size;
+    const int32 input_index = data_indices[slice_id];
+    if (input_index != -1) {
+      output[output_index] = ldg(data_ptrs[input_index] + slice_offset);
+    }
+  }
+}
+
+}  // namespace
+
+template <typename T>
+void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
+                          const int32 slice_size, const int32 first_dim_size,
+                          const CudaDeviceArrayStruct<int>& input_indices,
+                          const CudaDeviceArrayStruct<const T*>& input_ptrs,
+                          T* output) {
+  const int32 output_size = first_dim_size * slice_size;
+  auto config = GetCudaLaunchConfig(output_size, gpu_device);
+
+  DynamicStitchKernel<T>
+      <<<config.block_count, config.thread_per_block, 0, gpu_device.stream()>>>(
+          slice_size, output_size, input_indices, input_ptrs, output);
+}
+
+#define REGISTER_GPU(T)                                           \
+  template void DynamicStitchGPUImpl(                             \
+      const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
+      const int32 first_dim_size,                                 \
+      const CudaDeviceArrayStruct<int32>& input_indices,          \
+      const CudaDeviceArrayStruct<const T*>& input_ptrs, T* output);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_complex64(REGISTER_GPU);
+TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
+TF_CALL_int32(REGISTER_GPU)
+
+#undef REGISTER_GPU
+
+}  // namespace tensorflow
+#endif  // GOOGLE_CUDA
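The kernel flattens the output to `first_dim_size * slice_size` elements and gives each element to one thread of the grid-stride loop; a thread recovers its output row and within-row offset by division and modulo, then gathers through the pointer table built on the host. The same index arithmetic in an illustrative NumPy sketch:

import numpy as np

def dynamic_stitch_kernel(slice_size, indices_flat, data_flat, output):
  # One loop iteration per output element, mirroring CUDA_1D_KERNEL_LOOP.
  for output_index in range(len(output)):
    slice_id = output_index // slice_size      # which output row
    slice_offset = output_index % slice_size   # position within the row
    input_index = indices_flat[slice_id]       # flattened source slice, or -1
    if input_index != -1:
      output[output_index] = data_flat[input_index][slice_offset]

out = np.zeros(6)
dynamic_stitch_kernel(2, [1, -1, 0],
                      [np.array([1.0, 2.0]), np.array([3.0, 4.0])], out)
print(out)  # [3. 4. 0. 0. 1. 2.]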
tensorflow/python/kernel_tests/dynamic_stitch_op_test.py:

@@ -21,6 +21,7 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gradients_impl
 import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
@@ -33,7 +34,7 @@ class DynamicStitchTestBase(object):
     self.stitch_op = stitch_op
 
   def testScalar(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       indices = [constant_op.constant(0), constant_op.constant(1)]
       data = [constant_op.constant(40), constant_op.constant(60)]
       for step in -1, 1:
@@ -46,7 +47,7 @@ class DynamicStitchTestBase(object):
       self.assertEqual([None], stitched_t.get_shape().as_list())
 
   def testSimpleOneDimensional(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       indices = [
           constant_op.constant([0, 4, 7]), constant_op.constant([1, 6, 2, 3, 5])
       ]
@@ -63,7 +64,7 @@ class DynamicStitchTestBase(object):
       self.assertEqual([None], stitched_t.get_shape().as_list())
 
   def testOneListOneDimensional(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       indices = [constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])]
       data = [constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])]
       stitched_t = self.stitch_op(indices, data)
@@ -75,7 +76,7 @@ class DynamicStitchTestBase(object):
       self.assertEqual([None], stitched_t.get_shape().as_list())
 
   def testSimpleTwoDimensional(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       indices = [
           constant_op.constant([0, 4, 7]), constant_op.constant([1, 6]),
           constant_op.constant([2, 3, 5])
@@ -95,7 +96,7 @@ class DynamicStitchTestBase(object):
       self.assertEqual([None, 2], stitched_t.get_shape().as_list())
 
   def testHigherRank(self):
-    with self.test_session() as sess:
+    with self.test_session(use_gpu=True) as sess:
       indices = [
           constant_op.constant(6), constant_op.constant([4, 1]),
           constant_op.constant([[5, 2], [0, 3]])
@@ -176,6 +177,45 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
     test.TestCase.__init__(self, *test_case_args)
     DynamicStitchTestBase.__init__(self, data_flow_ops.parallel_dynamic_stitch)
 
+  def testScalar(self):
+    with self.test_session(use_gpu=True):
+      indices = [constant_op.constant(0), constant_op.constant(1)]
+      data = [constant_op.constant(40.0), constant_op.constant(60.0)]
+      for step in -1, 1:
+        stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
+        stitched_val = stitched_t.eval()
+        self.assertAllEqual([40.0, 60.0][::step], stitched_val)
+        # Dimension 0 is determined by the max index in indices, so we
+        # can only infer that the output is a vector of some unknown
+        # length.
+        self.assertEqual([None], stitched_t.get_shape().as_list())
+
+  def testHigherRank(self):
+    with self.test_session(use_gpu=True) as sess:
+      indices = [
+          constant_op.constant(6),
+          constant_op.constant([4, 1]),
+          constant_op.constant([[5, 2], [0, 3]])
+      ]
+      data = [
+          constant_op.constant([61, 62], dtype=dtypes.float32),
+          constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32),
+          constant_op.constant(
+              [[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32)
+      ]
+      stitched_t = data_flow_ops.dynamic_stitch(indices, data)
+      stitched_val = stitched_t.eval()
+      correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
+      self.assertAllEqual(correct, stitched_val)
+      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+      # Test gradients
+      stitched_grad = 7 * stitched_val
+      grads = gradients_impl.gradients(stitched_t, indices + data,
+                                       stitched_grad)
+      self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
+      for datum, grad in zip(data, sess.run(grads[3:])):
+        self.assertAllEqual(7.0 * datum.eval(), grad)
+
 
 if __name__ == "__main__":
   test.main()
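The `correct` expression in the added `testHigherRank` leans on the test data's encoding: stitched element `[i, j]` equals `10*i + (j+1)`, so the whole expectation is a single broadcast. Checked standalone:

import numpy as np

correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
print(correct)
# [[ 1.  2.]
#  [11. 12.]
#  ...
#  [61. 62.]]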