Factor out DenseUpdate ops into dense_update_functor build dep.

Also add support for complex types.

PiperOrigin-RevId: 161726749
Eugene Brevdo, 2017-07-12 15:24:40 -07:00, committed by TensorFlower Gardener
parent 9a172989e0
commit c65f691196
10 changed files with 164 additions and 127 deletions
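Consumers now share a single functor target instead of depending on the full `dense_update_ops` kernel library. A compile-only sketch of the call pattern the new `:dense_update_functor` dependency provides (assuming TensorFlow's headers; `ApplyDenseUpdate` is a hypothetical helper, not part of this commit — only `functor::DenseUpdate`, `DenseUpdateType`, and `TTypes` appear in the diff below):

    // Compile-only sketch; ApplyDenseUpdate is hypothetical.
    #include "tensorflow/core/kernels/dense_update_functor.h"

    namespace tensorflow {

    template <typename Device, typename T, DenseUpdateType OP>
    void ApplyDenseUpdate(const Device& d, typename TTypes<T>::Flat params,
                          typename TTypes<T>::ConstFlat update) {
      // Resolves inline against the CPU specializations in
      // dense_update_functor.h, or links against the explicit GPU
      // instantiations from dense_update_functor_gpu.cu.cc.
      functor::DenseUpdate<Device, T, OP> update_functor;
      update_functor(d, params, update);
    }

    }  // namespace tensorflow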

tensorflow/core/framework/register_types.h

@@ -172,7 +172,9 @@ limitations under the License.
   TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
 
 // Call "m" on all types supported on GPU.
-#define TF_CALL_GPU_ALL_TYPES(m) TF_CALL_GPU_NUMBER_TYPES(m) TF_CALL_bool(m)
+#define TF_CALL_GPU_ALL_TYPES(m) \
+  TF_CALL_GPU_NUMBER_TYPES(m)    \
+  TF_CALL_bool(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
 
 #define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m)
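The complex64/complex128 additions matter because TF_CALL_GPU_ALL_TYPES is an "X macro": every macro `m` it is applied to — such as the DEFINE_GPU_KERNELS instantiation macro in dense_update_functor_gpu.cu.cc below — now fires for the complex types as well. A self-contained illustration of the mechanism (plain C++, not TF code; `CALL_GPU_ALL_TYPES` and `PRINT_TYPE` are stand-ins):

    // Stand-in for TF_CALL_GPU_ALL_TYPES: applies "m" to each supported type.
    #include <complex>
    #include <iostream>

    #define CALL_GPU_ALL_TYPES(m) \
      m(float) m(double) m(bool) m(std::complex<float>) m(std::complex<double>)

    // Stand-in for a per-type instantiation macro like DEFINE_GPU_KERNELS.
    #define PRINT_TYPE(T) std::cout << #T << "\n";

    int main() {
      CALL_GPU_ALL_TYPES(PRINT_TYPE);  // one line per type, complex included
      return 0;
    }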

tensorflow/core/kernels/BUILD

@@ -19,6 +19,7 @@ package_group(
     name = "friends",
     packages = [
         "//learning/brain/contrib/...",
+        "//learning/brain/research/sparse_matrix/...",
         "//tensorflow/...",
     ],
 )
@@ -94,13 +95,11 @@ tf_kernel_library(
         "strided_slice_op_inst_7.cc",
     ],
     hdrs = [
-        "dense_update_ops.h",
         "slice_op.h",
         "strided_slice_op.h",
         "strided_slice_op_impl.h",
     ],
     gpu_srcs = [
-        "dense_update_ops.h",
         "slice_op.h",
         "strided_slice_op.h",
         "strided_slice_op_impl.h",
@@ -109,6 +108,7 @@ tf_kernel_library(
     ],
     deps = [
         ":bounds_check",
+        ":dense_update_functor",
         ":ops_util",
         ":variable_ops",
         "//tensorflow/core:framework",
@@ -1011,6 +1011,23 @@ tf_kernel_library(
     ],
 )
 
+tf_kernel_library(
+    name = "dense_update_functor",
+    srcs = ["dense_update_functor.cc"],
+    hdrs = ["dense_update_functor.h"],
+    gpu_srcs = [
+        "dense_update_functor.h",
+        "dense_update_functor_gpu.cu.cc",
+    ],
+    visibility = [":friends"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//third_party/eigen3",
+    ],
+    alwayslink = 0,
+)
+
 tf_cuda_cc_test(
     name = "gather_op_test",
     size = "small",
@@ -1606,7 +1623,7 @@ tf_kernel_library(
     srcs = ["resource_variable_ops.cc"],
     deps = [
         ":bounds_check",
-        ":dense_update_ops",
+        ":dense_update_functor",
         ":gather_functor",
         ":scatter_functor",
         ":state",
@@ -2079,7 +2096,7 @@ tf_kernel_library(
         "//tensorflow:darwin": [],
         "//conditions:default": ["-Wl,-z,muldefs"],
     }),
-    visibility = ["//visibility:private"],
+    visibility = [":friends"],
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -3482,7 +3499,7 @@ tf_kernel_library(
 tf_kernel_library(
     name = "dense_update_ops",
     prefix = "dense_update_ops",
-    deps = STATE_DEPS,
+    deps = STATE_DEPS + [":dense_update_functor"],
 )
 
 tf_kernel_library(
@@ -3503,16 +3520,14 @@ tf_kernel_library(
         "scatter_nd_op_cpu_impl_5.cc",
     ],
     hdrs = [
-        "dense_update_ops.h",
         "scatter_nd_op.h",
         "scatter_nd_op_cpu_impl.h",
     ],
     gpu_srcs = [
-        "dense_update_ops.h",
         "scatter_nd_op.h",
         "scatter_nd_op_gpu.cu.cc",
     ],
-    deps = STATE_DEPS + [":dense_update_ops"],
+    deps = STATE_DEPS + [":dense_update_functor"],
 )
 
 tf_kernel_library(
@@ -4054,8 +4069,9 @@ filegroup(
         "cwise_ops_common.cc",
         "cwise_ops_common.h",
         "cwise_ops_gradients.h",
+        "dense_update_functor.cc",
+        "dense_update_functor.h",
         "dense_update_ops.cc",
-        "dense_update_ops.h",
         "example_parsing_ops.cc",
         "fill_functor.cc",
         "fill_functor.h",

tensorflow/core/kernels/dense_update_functor.cc

@@ -0,0 +1,73 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/dense_update_functor.h"
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+template <>
+struct DenseUpdate<CPUDevice, string, ASSIGN> {
+  void operator()(const CPUDevice& d, typename TTypes<string>::Flat params,
+                  typename TTypes<string>::ConstFlat update) {
+    if (params.dimension(0) == 1) {
+      params.data()->resize(update.data()->size());
+      auto work = [&params, &update](int64 start, int64 end) {
+        memmove(const_cast<char*>(params.data()->data()) + start,
+                update.data()->data() + start, end - start);
+      };
+      d.parallelFor(update.data()->size(),
+                    Eigen::TensorOpCost(.1,  // chosen to force large chunks
+                                        .1, 0),
+                    work);
+    } else {
+      auto work = [&params, &update](int64 start, int64 end) {
+        for (int i = start; i < end; ++i) {
+          params.data()[i].resize(update.data()[i].size());
+          memmove(const_cast<char*>(params.data()[i].data()),
+                  update.data()[i].data(), update.data()[i].size());
+        }
+      };
+      int64 estimated_string_size;
+      if (update.size() > 0) {
+        // first element of the tensor seems as good a guess as any of the
+        // sizes of the strings contained within...
+        estimated_string_size =
+            std::max(update.data()[0].size(), sizeof(string));
+      } else {
+        estimated_string_size = sizeof(string);
+      }
+      d.parallelFor(
+          params.dimension(0),
+          Eigen::TensorOpCost(estimated_string_size, estimated_string_size, 0),
+          work);
+    }
+  }
+};
+
+}  // namespace functor
+}  // namespace tensorflow
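The string ASSIGN specialization above shards its copy with Eigen's parallelFor, and the TensorOpCost argument controls shard size (the tiny .1 costs deliberately make Eigen hand out large chunks). A self-contained sketch of that API, assuming only Eigen's unsupported Tensor module; the pool size and costs are arbitrary example values:

    // Minimal parallelFor example mirroring the sharding pattern above.
    #define EIGEN_USE_THREADS
    #include <vector>
    #include "unsupported/Eigen/CXX11/Tensor"

    int main() {
      Eigen::ThreadPool pool(4);
      Eigen::ThreadPoolDevice device(&pool, 4);
      std::vector<int> out(1 << 20);
      auto work = [&out](Eigen::Index start, Eigen::Index end) {
        for (Eigen::Index i = start; i < end; ++i) out[i] = static_cast<int>(i);
      };
      // Larger per-element costs produce smaller shards; near-zero costs
      // (like the .1 values above) yield large chunks per worker.
      device.parallelFor(out.size(), Eigen::TensorOpCost(4, 4, 1), work);
      return 0;
    }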

tensorflow/core/kernels/dense_update_functor.h

@@ -13,40 +13,47 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_
-#define TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_
+#ifndef TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
+#define TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
+
+#define EIGEN_USE_THREADS
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
 
 namespace tensorflow {
 
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
 enum DenseUpdateType { ADD, SUB, ASSIGN };
 
 namespace functor {
 
 template <typename Device, typename T, DenseUpdateType OP>
-struct DenseUpdate;
-
-template <typename Device, typename T>
-struct DenseUpdate<Device, T, ADD> {
-  void operator()(const Device& d, typename TTypes<T>::Flat params,
+struct DenseUpdate {
+  void operator()(const Device& d, typename TTypes<T>::Flat params,
+                  typename TTypes<T>::ConstFlat update);
+};
+
+template <typename T>
+struct DenseUpdate<CPUDevice, T, ADD> {
+  void operator()(const CPUDevice& d, typename TTypes<T>::Flat params,
                   typename TTypes<T>::ConstFlat update) {
     params.device(d) += update;
   }
 };
 
-template <typename Device, typename T>
-struct DenseUpdate<Device, T, SUB> {
-  void operator()(const Device& d, typename TTypes<T>::Flat params,
+template <typename T>
+struct DenseUpdate<CPUDevice, T, SUB> {
+  void operator()(const CPUDevice& d, typename TTypes<T>::Flat params,
                   typename TTypes<T>::ConstFlat update) {
     params.device(d) -= update;
   }
 };
 
-template <typename Device, typename T>
-struct DenseUpdate<Device, T, ASSIGN> {
-  void operator()(const Device& d, typename TTypes<T>::Flat params,
+template <typename T>
+struct DenseUpdate<CPUDevice, T, ASSIGN> {
+  void operator()(const CPUDevice& d, typename TTypes<T>::Flat params,
                   typename TTypes<T>::ConstFlat update) {
     params.device(d) = update;
   }
@@ -55,4 +62,4 @@ struct DenseUpdate<Device, T, ASSIGN> {
 }  // end namespace functor
 }  // end namespace tensorflow
 
-#endif  // TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_
+#endif  // TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
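The shape of the header changes here: the primary template now only declares operator(), CPU specializations define it inline, and the GPU specializations (next file) are defined and explicitly instantiated in CUDA code. A self-contained illustration of that pattern in plain C++ — none of these names are TF's; it just mirrors the structure:

    // Primary template declares; specializations define (mirrors the header).
    #include <iostream>

    struct CpuDev {};
    enum UpdateType { ADD, SUB, ASSIGN };

    template <typename Device, typename T, UpdateType OP>
    struct DenseUpdateLike {
      void operator()(const Device& d, T* param, T update);  // declared only
    };

    // Inline CPU specialization, like DenseUpdate<CPUDevice, T, ADD> above.
    template <typename T>
    struct DenseUpdateLike<CpuDev, T, ADD> {
      void operator()(const CpuDev&, T* param, T update) { *param += update; }
    };

    int main() {
      CpuDev cpu;
      double x = 1.0;
      DenseUpdateLike<CpuDev, double, ADD>()(cpu, &x, 2.5);
      std::cout << x << "\n";  // prints 3.5
      return 0;
    }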

tensorflow/core/kernels/dense_update_functor_gpu.cu.cc

@@ -17,7 +17,7 @@ limitations under the License.
 
 #define EIGEN_USE_GPU
 
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
 
 #include "tensorflow/core/framework/register_types.h"
@@ -25,6 +25,34 @@ namespace tensorflow {
 typedef Eigen::GpuDevice GPUDevice;
 
 namespace functor {
 
+template <typename T>
+struct DenseUpdate<GPUDevice, T, ASSIGN> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat params,
+                  typename TTypes<T>::ConstFlat update) {
+    params.device(d) = update;
+  }
+};
+
+template <typename T>
+struct DenseUpdate<GPUDevice, T, ADD> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat params,
+                  typename TTypes<T>::ConstFlat update) {
+    params.device(d) += update;
+  }
+};
+
+template <typename T>
+struct DenseUpdate<GPUDevice, T, SUB> {
+  void operator()(const GPUDevice& d, typename TTypes<T>::Flat params,
+                  typename TTypes<T>::ConstFlat update) {
+    params.device(d) -= update;
+  }
+};
+
 }  // namespace functor
 
 #define DEFINE_GPU_KERNELS(T)                              \
   template struct functor::DenseUpdate<GPUDevice, T, ADD>; \
   template struct functor::DenseUpdate<GPUDevice, T, SUB>;

tensorflow/core/kernels/dense_update_ops.cc

@@ -15,59 +15,20 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#include "tensorflow/core/kernels/dense_update_ops.h"
-
 #if GOOGLE_CUDA
 #define EIGEN_USE_GPU
 #endif
 
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/assign_op.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
-namespace functor {
-
-template <>
-struct DenseUpdate<Eigen::ThreadPoolDevice, string, ASSIGN> {
-  void operator()(const Eigen::ThreadPoolDevice& d,
-                  typename TTypes<string>::Flat params,
-                  typename TTypes<string>::ConstFlat update) {
-    if (params.dimension(0) == 1) {
-      params.data()->resize(update.data()->size());
-      auto work = [&params, &update](int64 start, int64 end) {
-        memmove(const_cast<char*>(params.data()->data()) + start,
-                update.data()->data() + start, end - start);
-      };
-      d.parallelFor(update.data()->size(),
-                    Eigen::TensorOpCost(.1,  // chosen to force large chunks
-                                        .1, 0),
-                    work);
-    } else {
-      auto work = [&params, &update](int64 start, int64 end) {
-        for (int i = start; i < end; ++i) {
-          params.data()[i].resize(update.data()[i].size());
-          memmove(const_cast<char*>(params.data()[i].data()),
-                  update.data()[i].data(), update.data()[i].size());
-        }
-      };
-      int64 estimated_string_size;
-      if (update.size() > 0) {
-        // first element of the tensor seems as good a guess as any of the
-        // sizes of the strings contained within...
-        estimated_string_size =
-            std::max(update.data()[0].size(), sizeof(string));
-      } else {
-        estimated_string_size = sizeof(string);
-      }
-      d.parallelFor(
-          params.dimension(0),
-          Eigen::TensorOpCost(estimated_string_size, estimated_string_size, 0),
-          work);
-    }
-  }
-};
-
-}  // namespace functor
-
 template <typename Device, typename T>
 class AssignOpT : public AssignOp {
@@ -117,7 +78,7 @@ class DenseUpdateOp : public OpKernel {
         errors::InvalidArgument("Parameters and update must be the same size"));
 
     functor::DenseUpdate<Device, T, OP> update_functor;
-    update_functor(context->eigen_device<Device>(), Tparams.flat<T>(),
+    update_functor(context->template eigen_device<Device>(), Tparams.flat<T>(),
                    Tupdate.flat<T>());
   }
@@ -143,13 +104,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
 
 // Only register 'Assign' on GPU for the subset of types also supported by
 // 'Variable' (see variable_ops.cc.)
 #define REGISTER_GPU_KERNELS(type)                                 \
-  namespace functor {                                              \
-  template <>                                                      \
-  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(           \
-      const GPUDevice& d, typename TTypes<type>::Flat lhs,         \
-      typename TTypes<type>::ConstFlat rhs);                       \
-  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>;     \
-  }                                                                \
   REGISTER_KERNEL_BUILDER(                                         \
       Name("Assign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
       AssignOpT<GPUDevice, type>);
@@ -180,22 +134,6 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
 #if GOOGLE_CUDA
-// Forward declarations of the functor specializations for GPU.
-namespace functor {
-#define DECLARE_GPU_SPEC_FOR_OP(T, OP)                     \
-  template <>                                              \
-  void DenseUpdate<GPUDevice, T, OP>::operator()(          \
-      const GPUDevice& d, typename TTypes<T>::Flat params, \
-      typename TTypes<T>::ConstFlat update);               \
-  extern template struct DenseUpdate<GPUDevice, T, OP>;
-#define DECLARE_GPU_SPEC(T)                         \
-  DECLARE_GPU_SPEC_FOR_OP(T, DenseUpdateType::ADD); \
-  DECLARE_GPU_SPEC_FOR_OP(T, DenseUpdateType::SUB)
-TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
-#undef DECLARE_GPU_SPEC
-#undef DECLARE_GPU_SPEC_FOR_OP
-}  // namespace functor
-
 #define REGISTER_GPU_KERNELS(type)                                    \
   REGISTER_KERNEL_BUILDER(                                            \
       Name("AssignAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \

tensorflow/core/kernels/resource_variable_ops.cc

@@ -15,12 +15,16 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif
+
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
 #include "tensorflow/core/kernels/gather_functor.h"
 #include "tensorflow/core/kernels/scatter_functor.h"
 #include "tensorflow/core/kernels/variable_ops.h"
@@ -217,13 +221,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
 
 #if GOOGLE_CUDA
 #define REGISTER_GPU_KERNELS(type)                             \
-  namespace functor {                                          \
-  template <>                                                  \
-  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(       \
-      const GPUDevice& d, typename TTypes<type>::Flat lhs,     \
-      typename TTypes<type>::ConstFlat rhs);                   \
-  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
-  }                                                            \
   REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")             \
                               .Device(DEVICE_GPU)              \
                               .TypeConstraint<type>("dtype")   \
@@ -275,20 +272,6 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
 
 #if GOOGLE_CUDA
 #define REGISTER_GPU_KERNELS(type)                          \
-  namespace functor {                                       \
-  template <>                                               \
-  void DenseUpdate<GPUDevice, type, ADD>::operator()(       \
-      const GPUDevice& d, typename TTypes<type>::Flat lhs,  \
-      typename TTypes<type>::ConstFlat rhs);                \
-  extern template struct DenseUpdate<GPUDevice, type, ADD>; \
-  }                                                         \
-  namespace functor {                                       \
-  template <>                                               \
-  void DenseUpdate<GPUDevice, type, SUB>::operator()(       \
-      const GPUDevice& d, typename TTypes<type>::Flat lhs,  \
-      typename TTypes<type>::ConstFlat rhs);                \
-  extern template struct DenseUpdate<GPUDevice, type, SUB>; \
-  }                                                         \
   REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp")       \
                               .Device(DEVICE_GPU)           \
                               .HostMemory("resource")       \

tensorflow/core/kernels/scatter_nd_op.cc

@@ -26,7 +26,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -523,16 +523,6 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);
 #undef DECLARE_GPU_SPECS_INDEX
 #undef DECLARE_GPU_SPECS_INDEX_OP
 
-#define REGISTER_GPU_KERNELS(type)                         \
-  template <>                                              \
-  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(   \
-      const GPUDevice& d, typename TTypes<type>::Flat lhs, \
-      typename TTypes<type>::ConstFlat rhs);               \
-  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>;
-
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
-#undef REGISTER_GPU_KERNELS
-
 }  // namespace functor
 
 #endif  // GOOGLE_CUDA

tensorflow/core/kernels/strided_slice_op.cc

@@ -22,7 +22,7 @@ limitations under the License.
 #endif  // GOOGLE_CUDA
 
 #include "tensorflow/core/kernels/strided_slice_op.h"
 
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
 #include "tensorflow/core/kernels/slice_op.h"
 #include "tensorflow/core/kernels/strided_slice_op_impl.h"

tensorflow/core/kernels/strided_slice_op_impl.h

@@ -27,7 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types_traits.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"