mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[OpenCL] Cleans dense_update_ops (#10335)
* [OpenCL] Cleans dense_update_ops * Acts on feedback from: https://github.com/tensorflow/tensorflow/pull/10335#discussion_r120536460
This commit is contained in:
parent
85f9681258
commit
036ce8ba64
|
|
@ -126,6 +126,9 @@ class DenseUpdateOp : public OpKernel {
|
||||||
|
|
||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
|
typedef Eigen::SyclDevice SYCLDevice;
|
||||||
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
#define REGISTER_KERNELS(type) \
|
#define REGISTER_KERNELS(type) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
|
@ -136,26 +139,6 @@ TF_CALL_ALL_TYPES(REGISTER_KERNELS);
|
||||||
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
|
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
|
||||||
#undef REGISTER_KERNELS
|
#undef REGISTER_KERNELS
|
||||||
|
|
||||||
#if TENSORFLOW_USE_SYCL
|
|
||||||
typedef Eigen::SyclDevice SYCLDevice;
|
|
||||||
#define REGISTER_SYCL_KERNEL(type) \
|
|
||||||
REGISTER_KERNEL_BUILDER( \
|
|
||||||
Name("Assign") \
|
|
||||||
.Device(DEVICE_SYCL) \
|
|
||||||
.TypeConstraint<type>("T"), \
|
|
||||||
AssignOpT<SYCLDevice, type>); \
|
|
||||||
REGISTER_KERNEL_BUILDER( \
|
|
||||||
Name("AssignAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
|
||||||
DenseUpdateOp<SYCLDevice, type, DenseUpdateType::ADD>); \
|
|
||||||
REGISTER_KERNEL_BUILDER( \
|
|
||||||
Name("AssignSub").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
|
||||||
DenseUpdateOp<SYCLDevice, type, DenseUpdateType::SUB>);
|
|
||||||
|
|
||||||
REGISTER_SYCL_KERNEL(float);
|
|
||||||
REGISTER_SYCL_KERNEL(double);
|
|
||||||
#undef REGISTER_SYCL_KERNEL
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
// Only register 'Assign' on GPU for the subset of types also supported by
|
// Only register 'Assign' on GPU for the subset of types also supported by
|
||||||
// 'Variable' (see variable_ops.cc.)
|
// 'Variable' (see variable_ops.cc.)
|
||||||
|
|
@ -175,6 +158,16 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
|
||||||
#undef REGISTER_GPU_KERNELS
|
#undef REGISTER_GPU_KERNELS
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
|
#define REGISTER_SYCL_KERNELS(type) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("Assign").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
||||||
|
AssignOpT<SYCLDevice, type>);
|
||||||
|
|
||||||
|
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
|
||||||
|
#undef REGISTER_SYCL_KERNELS
|
||||||
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
#define REGISTER_KERNELS(type) \
|
#define REGISTER_KERNELS(type) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||||
|
|
@ -214,4 +207,16 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
|
||||||
#undef REGISTER_GPU_KERNELS
|
#undef REGISTER_GPU_KERNELS
|
||||||
#endif // end GOOGLE_CUDA
|
#endif // end GOOGLE_CUDA
|
||||||
|
|
||||||
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
|
#define REGISTER_SYCL_KERNELS(type) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("AssignAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
||||||
|
DenseUpdateOp<SYCLDevice, type, DenseUpdateType::ADD>); \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("AssignSub").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
|
||||||
|
DenseUpdateOp<SYCLDevice, type, DenseUpdateType::SUB>);
|
||||||
|
|
||||||
|
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
|
||||||
|
#undef REGISTER_SYCL_KERNELS
|
||||||
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user