mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Add support for bools in matrix_diag, matrix_diag_part, matrix_set_diag, matrix_band_part.
PiperOrigin-RevId: 157939272
This commit is contained in:
parent
aad2e3daff
commit
7ffc357325
|
|
@ -83,7 +83,7 @@ class MatrixBandPartOp : public OpKernel {
|
|||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatrixBandPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
MatrixBandPartOp<CPUDevice, type>);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_BAND_PART);
|
||||
TF_CALL_POD_TYPES(REGISTER_MATRIX_BAND_PART);
|
||||
#undef REGISTER_MATRIX_BAND_PART
|
||||
|
||||
// Registration of the deprecated kernel.
|
||||
|
|
@ -143,6 +143,7 @@ namespace functor {
|
|||
extern template struct MatrixBandPart<GPUDevice, T>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
|
||||
TF_CALL_bool(DECLARE_GPU_SPEC);
|
||||
TF_CALL_complex64(DECLARE_GPU_SPEC);
|
||||
TF_CALL_complex128(DECLARE_GPU_SPEC);
|
||||
} // namespace functor
|
||||
|
|
@ -156,6 +157,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
|
|||
.HostMemory("num_upper"), \
|
||||
MatrixBandPartOp<GPUDevice, type>);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_BAND_PART_GPU);
|
||||
TF_CALL_bool(REGISTER_MATRIX_BAND_PART_GPU);
|
||||
TF_CALL_complex64(REGISTER_MATRIX_BAND_PART_GPU);
|
||||
TF_CALL_complex128(REGISTER_MATRIX_BAND_PART_GPU);
|
||||
#undef REGISTER_MATRIX_BAND_PART_GPU
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ typedef Eigen::GpuDevice GPUDevice;
|
|||
template struct functor::MatrixBandPart<GPUDevice, T>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
|
||||
TF_CALL_bool(DEFINE_GPU_SPEC);
|
||||
TF_CALL_complex64(DEFINE_GPU_SPEC);
|
||||
TF_CALL_complex128(DEFINE_GPU_SPEC);
|
||||
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ class MatrixDiagOp : public OpKernel {
|
|||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
MatrixDiagPartOp<CPUDevice, type>);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_DIAG);
|
||||
TF_CALL_POD_TYPES(REGISTER_MATRIX_DIAG);
|
||||
#undef REGISTER_MATRIX_DIAG
|
||||
|
||||
// Registration of the deprecated kernel.
|
||||
|
|
@ -136,7 +136,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_DIAG);
|
|||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T"), \
|
||||
MatrixDiagPartOp<CPUDevice, type>);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_DIAG);
|
||||
TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_DIAG);
|
||||
#undef REGISTER_BATCH_MATRIX_DIAG
|
||||
|
||||
// Implementation of the functor specialization for CPU.
|
||||
|
|
@ -187,6 +187,7 @@ namespace functor {
|
|||
extern template struct MatrixDiagPart<GPUDevice, T>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
|
||||
TF_CALL_bool(DECLARE_GPU_SPEC);
|
||||
TF_CALL_complex64(DECLARE_GPU_SPEC);
|
||||
TF_CALL_complex128(DECLARE_GPU_SPEC);
|
||||
|
||||
|
|
@ -201,6 +202,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
|
|||
Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||
MatrixDiagPartOp<GPUDevice, type>);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU);
|
||||
TF_CALL_bool(REGISTER_MATRIX_DIAG_GPU);
|
||||
TF_CALL_complex64(REGISTER_MATRIX_DIAG_GPU);
|
||||
TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU);
|
||||
#undef REGISTER_MATRIX_DIAG_GPU
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ typedef Eigen::GpuDevice GPUDevice;
|
|||
template struct functor::MatrixDiagPart<GPUDevice, T>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
|
||||
TF_CALL_bool(DEFINE_GPU_SPEC);
|
||||
TF_CALL_complex64(DEFINE_GPU_SPEC);
|
||||
TF_CALL_complex128(DEFINE_GPU_SPEC);
|
||||
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ class MatrixSetDiagOp : public OpKernel {
|
|||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("MatrixSetDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
MatrixSetDiagOp<CPUDevice, type>);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG);
|
||||
TF_CALL_POD_TYPES(REGISTER_MATRIX_SET_DIAG);
|
||||
#undef REGISTER_MATRIX_SET_DIAG
|
||||
|
||||
// Registration of the deprecated kernel.
|
||||
|
|
@ -109,7 +109,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG);
|
|||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BatchMatrixSetDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
MatrixSetDiagOp<CPUDevice, type>);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_SET_DIAG);
|
||||
TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_SET_DIAG);
|
||||
#undef REGISTER_BATCH_MATRIX_SET_DIAG
|
||||
|
||||
namespace functor {
|
||||
|
|
@ -131,6 +131,21 @@ struct MatrixSetDiag<CPUDevice, T> {
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MatrixSetDiag<CPUDevice, bool> {
|
||||
static void Compute(const CPUDevice& d, TTypes<bool, 3>::ConstTensor input,
|
||||
TTypes<bool, 2>::ConstTensor diag,
|
||||
TTypes<bool>::Scalar scratch,
|
||||
TTypes<bool, 3>::Tensor output) {
|
||||
output.device(d) = input;
|
||||
for (int64 r = 0; r < output.dimension(0); ++r) {
|
||||
for (int64 d = 0; d < diag.dimension(1); ++d) {
|
||||
output(r, d, d) = diag(r, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
|
@ -147,6 +162,7 @@ namespace functor {
|
|||
extern template struct MatrixSetDiag<GPUDevice, T>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
|
||||
TF_CALL_bool(DECLARE_GPU_SPEC);
|
||||
TF_CALL_complex64(DECLARE_GPU_SPEC);
|
||||
TF_CALL_complex128(DECLARE_GPU_SPEC);
|
||||
|
||||
|
|
@ -158,6 +174,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
|
|||
Name("MatrixSetDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
|
||||
MatrixSetDiagOp<GPUDevice, type>);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU);
|
||||
TF_CALL_bool(REGISTER_MATRIX_SET_DIAG_GPU);
|
||||
TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU);
|
||||
TF_CALL_complex128(REGISTER_MATRIX_SET_DIAG_GPU);
|
||||
#undef REGISTER_MATRIX_SET_DIAG_GPU
|
||||
|
|
|
|||
|
|
@ -71,6 +71,23 @@ struct MatrixSetDiag {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename Device>
|
||||
struct MatrixSetDiag<Device, bool> {
|
||||
EIGEN_ALWAYS_INLINE static void Compute(const Device& d,
|
||||
TTypes<bool, 3>::ConstTensor input,
|
||||
TTypes<bool, 2>::ConstTensor diag,
|
||||
TTypes<bool>::Scalar scratch,
|
||||
TTypes<bool, 3>::Tensor output) {
|
||||
output.device(d) = input;
|
||||
generator::OverwriteDiagGenerator<bool> generator(diag, output);
|
||||
// Use all() to force the generation to aggregate to the scalar
|
||||
// output scratch. This in turn forces each element of the
|
||||
// generator to execute. The side effect of the execution is to
|
||||
// update the diagonal components of output with diag.
|
||||
scratch.device(d) = diag.generate(generator).all();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ typedef Eigen::GpuDevice GPUDevice;
|
|||
template struct functor::MatrixSetDiag<GPUDevice, T>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
|
||||
TF_CALL_bool(DEFINE_GPU_SPEC);
|
||||
TF_CALL_complex64(DEFINE_GPU_SPEC);
|
||||
TF_CALL_complex128(DEFINE_GPU_SPEC);
|
||||
|
||||
|
|
|
|||
|
|
@ -938,7 +938,7 @@ cuda_py_test(
|
|||
tags = ["notsan"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
cuda_py_test(
|
||||
name = "diag_op_test",
|
||||
size = "medium",
|
||||
srcs = ["diag_op_test.py"],
|
||||
|
|
|
|||
|
|
@ -39,21 +39,23 @@ class MatrixDiagTest(test.TestCase):
|
|||
self.assertEqual((3, 3), v_diag.get_shape())
|
||||
self.assertAllEqual(v_diag.eval(), mat)
|
||||
|
||||
def testBatchVector(self):
|
||||
def _testBatchVector(self, dtype):
|
||||
with self.test_session(use_gpu=True):
|
||||
v_batch = np.array([[1.0, 2.0, 3.0],
|
||||
[4.0, 5.0, 6.0]])
|
||||
mat_batch = np.array(
|
||||
[[[1.0, 0.0, 0.0],
|
||||
[0.0, 2.0, 0.0],
|
||||
[0.0, 0.0, 3.0]],
|
||||
[[4.0, 0.0, 0.0],
|
||||
[0.0, 5.0, 0.0],
|
||||
[0.0, 0.0, 6.0]]])
|
||||
v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
|
||||
mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
|
||||
[[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
|
||||
[0.0, 0.0, 6.0]]]).astype(dtype)
|
||||
v_batch_diag = array_ops.matrix_diag(v_batch)
|
||||
self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
|
||||
self.assertAllEqual(v_batch_diag.eval(), mat_batch)
|
||||
|
||||
def testBatchVector(self):
|
||||
self._testBatchVector(np.float32)
|
||||
self._testBatchVector(np.float64)
|
||||
self._testBatchVector(np.int32)
|
||||
self._testBatchVector(np.int64)
|
||||
self._testBatchVector(np.bool)
|
||||
|
||||
def testInvalidShape(self):
|
||||
with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
|
||||
array_ops.matrix_diag(0)
|
||||
|
|
@ -108,29 +110,29 @@ class MatrixSetDiagTest(test.TestCase):
|
|||
self.assertEqual((3, 2), output.get_shape())
|
||||
self.assertAllEqual(expected, output.eval())
|
||||
|
||||
def testSquareBatch(self):
|
||||
def _testSquareBatch(self, dtype):
|
||||
with self.test_session(use_gpu=True):
|
||||
v_batch = np.array([[-1.0, -2.0, -3.0],
|
||||
[-4.0, -5.0, -6.0]])
|
||||
mat_batch = np.array(
|
||||
[[[1.0, 0.0, 3.0],
|
||||
[0.0, 2.0, 0.0],
|
||||
[1.0, 0.0, 3.0]],
|
||||
[[4.0, 0.0, 4.0],
|
||||
[0.0, 5.0, 0.0],
|
||||
[2.0, 0.0, 6.0]]])
|
||||
v_batch = np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]]).astype(dtype)
|
||||
mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
|
||||
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0],
|
||||
[2.0, 0.0, 6.0]]]).astype(dtype)
|
||||
|
||||
mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0],
|
||||
[1.0, 0.0, -3.0]],
|
||||
[[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0],
|
||||
[2.0, 0.0, -6.0]]]).astype(dtype)
|
||||
|
||||
mat_set_diag_batch = np.array(
|
||||
[[[-1.0, 0.0, 3.0],
|
||||
[0.0, -2.0, 0.0],
|
||||
[1.0, 0.0, -3.0]],
|
||||
[[-4.0, 0.0, 4.0],
|
||||
[0.0, -5.0, 0.0],
|
||||
[2.0, 0.0, -6.0]]])
|
||||
output = array_ops.matrix_set_diag(mat_batch, v_batch)
|
||||
self.assertEqual((2, 3, 3), output.get_shape())
|
||||
self.assertAllEqual(mat_set_diag_batch, output.eval())
|
||||
|
||||
def testSquareBatch(self):
|
||||
self._testSquareBatch(np.float32)
|
||||
self._testSquareBatch(np.float64)
|
||||
self._testSquareBatch(np.int32)
|
||||
self._testSquareBatch(np.int64)
|
||||
self._testSquareBatch(np.bool)
|
||||
|
||||
def testRectangularBatch(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
v_batch = np.array([[-1.0, -2.0],
|
||||
|
|
@ -220,22 +222,24 @@ class MatrixDiagPartTest(test.TestCase):
|
|||
mat_diag = array_ops.matrix_diag_part(mat)
|
||||
self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
|
||||
|
||||
def testSquareBatch(self):
|
||||
def _testSquareBatch(self, dtype):
|
||||
with self.test_session(use_gpu=True):
|
||||
v_batch = np.array([[1.0, 2.0, 3.0],
|
||||
[4.0, 5.0, 6.0]])
|
||||
mat_batch = np.array(
|
||||
[[[1.0, 0.0, 0.0],
|
||||
[0.0, 2.0, 0.0],
|
||||
[0.0, 0.0, 3.0]],
|
||||
[[4.0, 0.0, 0.0],
|
||||
[0.0, 5.0, 0.0],
|
||||
[0.0, 0.0, 6.0]]])
|
||||
v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
|
||||
mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
|
||||
[[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
|
||||
[0.0, 0.0, 6.0]]]).astype(dtype)
|
||||
self.assertEqual(mat_batch.shape, (2, 3, 3))
|
||||
mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
|
||||
self.assertEqual((2, 3), mat_batch_diag.get_shape())
|
||||
self.assertAllEqual(mat_batch_diag.eval(), v_batch)
|
||||
|
||||
def testSquareBatch(self):
|
||||
self._testSquareBatch(np.float32)
|
||||
self._testSquareBatch(np.float64)
|
||||
self._testSquareBatch(np.int32)
|
||||
self._testSquareBatch(np.int64)
|
||||
self._testSquareBatch(np.bool)
|
||||
|
||||
def testRectangularBatch(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
v_batch = np.array([[1.0, 2.0],
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user