Add support for bools in matrix_diag, matrix_diag_part, matrix_set_diag, matrix_band_part.

PiperOrigin-RevId: 157939272
This commit is contained in:
Eugene Brevdo 2017-06-03 18:22:52 -07:00 committed by TensorFlower Gardener
parent aad2e3daff
commit 7ffc357325
9 changed files with 88 additions and 43 deletions

View File

@ -83,7 +83,7 @@ class MatrixBandPartOp : public OpKernel {
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
Name("MatrixBandPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ Name("MatrixBandPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixBandPartOp<CPUDevice, type>); MatrixBandPartOp<CPUDevice, type>);
TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_BAND_PART); TF_CALL_POD_TYPES(REGISTER_MATRIX_BAND_PART);
#undef REGISTER_MATRIX_BAND_PART #undef REGISTER_MATRIX_BAND_PART
// Registration of the deprecated kernel. // Registration of the deprecated kernel.
@ -143,6 +143,7 @@ namespace functor {
extern template struct MatrixBandPart<GPUDevice, T>; extern template struct MatrixBandPart<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC); TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC); TF_CALL_complex128(DECLARE_GPU_SPEC);
} // namespace functor } // namespace functor
@ -156,6 +157,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
.HostMemory("num_upper"), \ .HostMemory("num_upper"), \
MatrixBandPartOp<GPUDevice, type>); MatrixBandPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_BAND_PART_GPU); TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_BAND_PART_GPU);
TF_CALL_bool(REGISTER_MATRIX_BAND_PART_GPU);
TF_CALL_complex64(REGISTER_MATRIX_BAND_PART_GPU); TF_CALL_complex64(REGISTER_MATRIX_BAND_PART_GPU);
TF_CALL_complex128(REGISTER_MATRIX_BAND_PART_GPU); TF_CALL_complex128(REGISTER_MATRIX_BAND_PART_GPU);
#undef REGISTER_MATRIX_BAND_PART_GPU #undef REGISTER_MATRIX_BAND_PART_GPU

View File

@ -29,6 +29,7 @@ typedef Eigen::GpuDevice GPUDevice;
template struct functor::MatrixBandPart<GPUDevice, T>; template struct functor::MatrixBandPart<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_bool(DEFINE_GPU_SPEC);
TF_CALL_complex64(DEFINE_GPU_SPEC); TF_CALL_complex64(DEFINE_GPU_SPEC);
TF_CALL_complex128(DEFINE_GPU_SPEC); TF_CALL_complex128(DEFINE_GPU_SPEC);

View File

@ -123,7 +123,7 @@ class MatrixDiagOp : public OpKernel {
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixDiagPartOp<CPUDevice, type>); MatrixDiagPartOp<CPUDevice, type>);
TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_DIAG); TF_CALL_POD_TYPES(REGISTER_MATRIX_DIAG);
#undef REGISTER_MATRIX_DIAG #undef REGISTER_MATRIX_DIAG
// Registration of the deprecated kernel. // Registration of the deprecated kernel.
@ -136,7 +136,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_DIAG);
.Device(DEVICE_CPU) \ .Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \ .TypeConstraint<type>("T"), \
MatrixDiagPartOp<CPUDevice, type>); MatrixDiagPartOp<CPUDevice, type>);
TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_DIAG); TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_DIAG);
#undef REGISTER_BATCH_MATRIX_DIAG #undef REGISTER_BATCH_MATRIX_DIAG
// Implementation of the functor specialization for CPU. // Implementation of the functor specialization for CPU.
@ -187,6 +187,7 @@ namespace functor {
extern template struct MatrixDiagPart<GPUDevice, T>; extern template struct MatrixDiagPart<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC); TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC); TF_CALL_complex128(DECLARE_GPU_SPEC);
@ -201,6 +202,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
MatrixDiagPartOp<GPUDevice, type>); MatrixDiagPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU); TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU);
TF_CALL_bool(REGISTER_MATRIX_DIAG_GPU);
TF_CALL_complex64(REGISTER_MATRIX_DIAG_GPU); TF_CALL_complex64(REGISTER_MATRIX_DIAG_GPU);
TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU); TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU);
#undef REGISTER_MATRIX_DIAG_GPU #undef REGISTER_MATRIX_DIAG_GPU

View File

@ -31,6 +31,7 @@ typedef Eigen::GpuDevice GPUDevice;
template struct functor::MatrixDiagPart<GPUDevice, T>; template struct functor::MatrixDiagPart<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_bool(DEFINE_GPU_SPEC);
TF_CALL_complex64(DEFINE_GPU_SPEC); TF_CALL_complex64(DEFINE_GPU_SPEC);
TF_CALL_complex128(DEFINE_GPU_SPEC); TF_CALL_complex128(DEFINE_GPU_SPEC);

View File

@ -100,7 +100,7 @@ class MatrixSetDiagOp : public OpKernel {
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
Name("MatrixSetDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ Name("MatrixSetDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixSetDiagOp<CPUDevice, type>); MatrixSetDiagOp<CPUDevice, type>);
TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG); TF_CALL_POD_TYPES(REGISTER_MATRIX_SET_DIAG);
#undef REGISTER_MATRIX_SET_DIAG #undef REGISTER_MATRIX_SET_DIAG
// Registration of the deprecated kernel. // Registration of the deprecated kernel.
@ -109,7 +109,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG);
REGISTER_KERNEL_BUILDER( \ REGISTER_KERNEL_BUILDER( \
Name("BatchMatrixSetDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ Name("BatchMatrixSetDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixSetDiagOp<CPUDevice, type>); MatrixSetDiagOp<CPUDevice, type>);
TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_SET_DIAG); TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_SET_DIAG);
#undef REGISTER_BATCH_MATRIX_SET_DIAG #undef REGISTER_BATCH_MATRIX_SET_DIAG
namespace functor { namespace functor {
@ -131,6 +131,21 @@ struct MatrixSetDiag<CPUDevice, T> {
} }
}; };
template <>
struct MatrixSetDiag<CPUDevice, bool> {
static void Compute(const CPUDevice& d, TTypes<bool, 3>::ConstTensor input,
TTypes<bool, 2>::ConstTensor diag,
TTypes<bool>::Scalar scratch,
TTypes<bool, 3>::Tensor output) {
output.device(d) = input;
for (int64 r = 0; r < output.dimension(0); ++r) {
for (int64 d = 0; d < diag.dimension(1); ++d) {
output(r, d, d) = diag(r, d);
}
}
}
};
} // namespace functor } // namespace functor
#if GOOGLE_CUDA #if GOOGLE_CUDA
@ -147,6 +162,7 @@ namespace functor {
extern template struct MatrixSetDiag<GPUDevice, T>; extern template struct MatrixSetDiag<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC); TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC); TF_CALL_complex128(DECLARE_GPU_SPEC);
@ -158,6 +174,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
Name("MatrixSetDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ Name("MatrixSetDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
MatrixSetDiagOp<GPUDevice, type>); MatrixSetDiagOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU); TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU);
TF_CALL_bool(REGISTER_MATRIX_SET_DIAG_GPU);
TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU); TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU);
TF_CALL_complex128(REGISTER_MATRIX_SET_DIAG_GPU); TF_CALL_complex128(REGISTER_MATRIX_SET_DIAG_GPU);
#undef REGISTER_MATRIX_SET_DIAG_GPU #undef REGISTER_MATRIX_SET_DIAG_GPU

View File

@ -71,6 +71,23 @@ struct MatrixSetDiag {
} }
}; };
// Partial specialization of MatrixSetDiag for bool elements on any Device.
//
// First copies `input` into `output`, then evaluates an OverwriteDiagGenerator
// over `diag` and reduces the generated values into `scratch` with all().
template <typename Device>
struct MatrixSetDiag<Device, bool> {
  EIGEN_ALWAYS_INLINE static void Compute(const Device& d,
                                          TTypes<bool, 3>::ConstTensor input,
                                          TTypes<bool, 2>::ConstTensor diag,
                                          TTypes<bool>::Scalar scratch,
                                          TTypes<bool, 3>::Tensor output) {
    using DiagOverwriter = generator::OverwriteDiagGenerator<bool>;

    // Start from an exact copy of the input.
    output.device(d) = input;

    DiagOverwriter overwrite_diag(diag, output);
    // Reducing with all() into the scalar `scratch` forces every element of
    // the generated expression to be evaluated. The reduced value itself is
    // not the point: evaluating the generator has the side effect of writing
    // the entries of `diag` onto the diagonal of `output`.
    scratch.device(d) = diag.generate(overwrite_diag).all();
  }
};
} // namespace functor } // namespace functor
} // namespace tensorflow } // namespace tensorflow

View File

@ -29,6 +29,7 @@ typedef Eigen::GpuDevice GPUDevice;
template struct functor::MatrixSetDiag<GPUDevice, T>; template struct functor::MatrixSetDiag<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_bool(DEFINE_GPU_SPEC);
TF_CALL_complex64(DEFINE_GPU_SPEC); TF_CALL_complex64(DEFINE_GPU_SPEC);
TF_CALL_complex128(DEFINE_GPU_SPEC); TF_CALL_complex128(DEFINE_GPU_SPEC);

View File

@ -938,7 +938,7 @@ cuda_py_test(
tags = ["notsan"], tags = ["notsan"],
) )
tf_py_test( cuda_py_test(
name = "diag_op_test", name = "diag_op_test",
size = "medium", size = "medium",
srcs = ["diag_op_test.py"], srcs = ["diag_op_test.py"],

View File

@ -39,21 +39,23 @@ class MatrixDiagTest(test.TestCase):
self.assertEqual((3, 3), v_diag.get_shape()) self.assertEqual((3, 3), v_diag.get_shape())
self.assertAllEqual(v_diag.eval(), mat) self.assertAllEqual(v_diag.eval(), mat)
def testBatchVector(self): def _testBatchVector(self, dtype):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
v_batch = np.array([[1.0, 2.0, 3.0], v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
[4.0, 5.0, 6.0]]) mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
mat_batch = np.array( [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
[[[1.0, 0.0, 0.0], [0.0, 0.0, 6.0]]]).astype(dtype)
[0.0, 2.0, 0.0],
[0.0, 0.0, 3.0]],
[[4.0, 0.0, 0.0],
[0.0, 5.0, 0.0],
[0.0, 0.0, 6.0]]])
v_batch_diag = array_ops.matrix_diag(v_batch) v_batch_diag = array_ops.matrix_diag(v_batch)
self.assertEqual((2, 3, 3), v_batch_diag.get_shape()) self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
self.assertAllEqual(v_batch_diag.eval(), mat_batch) self.assertAllEqual(v_batch_diag.eval(), mat_batch)
def testBatchVector(self):
self._testBatchVector(np.float32)
self._testBatchVector(np.float64)
self._testBatchVector(np.int32)
self._testBatchVector(np.int64)
self._testBatchVector(np.bool)
def testInvalidShape(self): def testInvalidShape(self):
with self.assertRaisesRegexp(ValueError, "must be at least rank 1"): with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
array_ops.matrix_diag(0) array_ops.matrix_diag(0)
@ -108,29 +110,29 @@ class MatrixSetDiagTest(test.TestCase):
self.assertEqual((3, 2), output.get_shape()) self.assertEqual((3, 2), output.get_shape())
self.assertAllEqual(expected, output.eval()) self.assertAllEqual(expected, output.eval())
def testSquareBatch(self): def _testSquareBatch(self, dtype):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
v_batch = np.array([[-1.0, -2.0, -3.0], v_batch = np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]]).astype(dtype)
[-4.0, -5.0, -6.0]]) mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
mat_batch = np.array( [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0],
[[[1.0, 0.0, 3.0], [2.0, 0.0, 6.0]]]).astype(dtype)
[0.0, 2.0, 0.0],
[1.0, 0.0, 3.0]], mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0],
[[4.0, 0.0, 4.0], [1.0, 0.0, -3.0]],
[0.0, 5.0, 0.0], [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0],
[2.0, 0.0, 6.0]]]) [2.0, 0.0, -6.0]]]).astype(dtype)
mat_set_diag_batch = np.array(
[[[-1.0, 0.0, 3.0],
[0.0, -2.0, 0.0],
[1.0, 0.0, -3.0]],
[[-4.0, 0.0, 4.0],
[0.0, -5.0, 0.0],
[2.0, 0.0, -6.0]]])
output = array_ops.matrix_set_diag(mat_batch, v_batch) output = array_ops.matrix_set_diag(mat_batch, v_batch)
self.assertEqual((2, 3, 3), output.get_shape()) self.assertEqual((2, 3, 3), output.get_shape())
self.assertAllEqual(mat_set_diag_batch, output.eval()) self.assertAllEqual(mat_set_diag_batch, output.eval())
def testSquareBatch(self):
self._testSquareBatch(np.float32)
self._testSquareBatch(np.float64)
self._testSquareBatch(np.int32)
self._testSquareBatch(np.int64)
self._testSquareBatch(np.bool)
def testRectangularBatch(self): def testRectangularBatch(self):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
v_batch = np.array([[-1.0, -2.0], v_batch = np.array([[-1.0, -2.0],
@ -220,22 +222,24 @@ class MatrixDiagPartTest(test.TestCase):
mat_diag = array_ops.matrix_diag_part(mat) mat_diag = array_ops.matrix_diag_part(mat)
self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0])) self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
def testSquareBatch(self): def _testSquareBatch(self, dtype):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
v_batch = np.array([[1.0, 2.0, 3.0], v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
[4.0, 5.0, 6.0]]) mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
mat_batch = np.array( [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
[[[1.0, 0.0, 0.0], [0.0, 0.0, 6.0]]]).astype(dtype)
[0.0, 2.0, 0.0],
[0.0, 0.0, 3.0]],
[[4.0, 0.0, 0.0],
[0.0, 5.0, 0.0],
[0.0, 0.0, 6.0]]])
self.assertEqual(mat_batch.shape, (2, 3, 3)) self.assertEqual(mat_batch.shape, (2, 3, 3))
mat_batch_diag = array_ops.matrix_diag_part(mat_batch) mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
self.assertEqual((2, 3), mat_batch_diag.get_shape()) self.assertEqual((2, 3), mat_batch_diag.get_shape())
self.assertAllEqual(mat_batch_diag.eval(), v_batch) self.assertAllEqual(mat_batch_diag.eval(), v_batch)
def testSquareBatch(self):
self._testSquareBatch(np.float32)
self._testSquareBatch(np.float64)
self._testSquareBatch(np.int32)
self._testSquareBatch(np.int64)
self._testSquareBatch(np.bool)
def testRectangularBatch(self): def testRectangularBatch(self):
with self.test_session(use_gpu=True): with self.test_session(use_gpu=True):
v_batch = np.array([[1.0, 2.0], v_batch = np.array([[1.0, 2.0],