[OpenCL] Cleans dense_update_ops (#10335)

* [OpenCL] Cleans dense_update_ops

* Acts on feedback from: https://github.com/tensorflow/tensorflow/pull/10335#discussion_r120536460
This commit is contained in:
Luke Iwanski 2017-06-07 22:27:57 +01:00 committed by gunan
parent 85f9681258
commit 036ce8ba64

View File

@ -126,6 +126,9 @@ class DenseUpdateOp : public OpKernel {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
@ -136,26 +139,6 @@ TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#define REGISTER_SYCL_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("Assign") \
.Device(DEVICE_SYCL) \
.TypeConstraint<type>("T"), \
AssignOpT<SYCLDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("AssignAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
DenseUpdateOp<SYCLDevice, type, DenseUpdateType::ADD>); \
REGISTER_KERNEL_BUILDER( \
Name("AssignSub").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
DenseUpdateOp<SYCLDevice, type, DenseUpdateType::SUB>);
REGISTER_SYCL_KERNEL(float);
REGISTER_SYCL_KERNEL(double);
#undef REGISTER_SYCL_KERNEL
#endif
#if GOOGLE_CUDA
// Only register 'Assign' on GPU for the subset of types also supported by
// 'Variable' (see variable_ops.cc.)
@ -175,6 +158,16 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Assign").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
AssignOpT<SYCLDevice, type>);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS
#endif // TENSORFLOW_USE_SYCL
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
@ -214,4 +207,16 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // end GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("AssignAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
DenseUpdateOp<SYCLDevice, type, DenseUpdateType::ADD>); \
REGISTER_KERNEL_BUILDER( \
Name("AssignSub").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
DenseUpdateOp<SYCLDevice, type, DenseUpdateType::SUB>);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
#undef REGISTER_SYCL_KERNELS
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow