Make tf.eye accept Python integer shapes and avoid generating unnecessary shape handling ops.

Clean up test and add tests with placeholders.

PiperOrigin-RevId: 165746090
This commit is contained in:
A. Unique TensorFlower 2017-08-18 13:38:19 -07:00 committed by TensorFlower Gardener
parent 109ecf823d
commit 378463ae89
2 changed files with 91 additions and 190 deletions

View File

@ -27,7 +27,14 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
def _random_pd_matrix(n, rng):
def _AddTest(test_class, op_name, testcase_name, fn):
test_name = "_".join(["test", op_name, testcase_name])
if hasattr(test_class, test_name):
raise RuntimeError("Test %s defined more than once" % test_name)
setattr(test_class, test_name, fn)
def _RandomPDMatrix(n, rng):
"""Random positive definite matrix."""
temp = rng.randn(n, n)
return temp.dot(temp.T)
@ -44,8 +51,8 @@ class CholeskySolveTest(test.TestCase):
for np_type, atol in [(np.float32, 0.05), (np.float64, 1e-5)]:
# Create 2 x n x n matrix
array = np.array(
[_random_pd_matrix(n, self.rng), _random_pd_matrix(n, self.rng)
]).astype(np_type)
[_RandomPDMatrix(n, self.rng),
_RandomPDMatrix(n, self.rng)]).astype(np_type)
chol = linalg_ops.cholesky(array)
for k in range(1, 3):
rhs = self.rng.randn(2, n, k).astype(np_type)
@ -55,174 +62,58 @@ class CholeskySolveTest(test.TestCase):
class EyeTest(test.TestCase):
def test_non_batch_2x2(self):
num_rows = 2
dtype = np.float32
np_eye = np.eye(num_rows).astype(dtype)
with self.test_session(use_gpu=True):
eye = linalg_ops.eye(num_rows, dtype=dtype)
self.assertAllEqual((num_rows, num_rows), eye.get_shape())
self.assertAllEqual(np_eye, eye.eval())
def test_non_batch_2x3(self):
num_rows = 2
num_columns = 3
dtype = np.float32
np_eye = np.eye(num_rows, num_columns).astype(dtype)
with self.test_session(use_gpu=True):
eye = linalg_ops.eye(num_rows, num_columns=num_columns, dtype=dtype)
self.assertAllEqual((num_rows, num_columns), eye.get_shape())
self.assertAllEqual(np_eye, eye.eval())
def test_1x3_batch_4x4(self):
num_rows = 4
batch_shape = [1, 3]
dtype = np.float32
np_eye = np.eye(num_rows).astype(dtype)
with self.test_session(use_gpu=True):
eye = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype)
self.assertAllEqual(batch_shape + [num_rows, num_rows], eye.get_shape())
eye_v = eye.eval()
for i in range(batch_shape[0]):
for j in range(batch_shape[1]):
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
def test_1x3_batch_4x4_dynamic(self):
num_rows = 4
batch_shape = [1, 3]
dtype = np.float32
np_eye = np.eye(num_rows).astype(dtype)
with self.test_session(use_gpu=True):
num_rows_ph = array_ops.placeholder(dtypes.int32)
batch_shape_ph = array_ops.placeholder(dtypes.int32)
eye = linalg_ops.eye(num_rows_ph, batch_shape=batch_shape_ph, dtype=dtype)
eye_v = eye.eval(
feed_dict={num_rows_ph: num_rows,
batch_shape_ph: batch_shape})
for i in range(batch_shape[0]):
for j in range(batch_shape[1]):
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
def test_1x3_batch_5x4(self):
num_rows = 5
num_columns = 4
batch_shape = [1, 3]
dtype = np.float32
np_eye = np.eye(num_rows, num_columns).astype(dtype)
with self.test_session(use_gpu=True):
eye = linalg_ops.eye(num_rows,
num_columns=num_columns,
batch_shape=batch_shape,
dtype=dtype)
self.assertAllEqual(batch_shape + [num_rows, num_columns],
eye.get_shape())
eye_v = eye.eval()
for i in range(batch_shape[0]):
for j in range(batch_shape[1]):
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
def test_1x3_batch_5x4_dynamic(self):
num_rows = 5
num_columns = 4
batch_shape = [1, 3]
dtype = np.float32
np_eye = np.eye(num_rows, num_columns).astype(dtype)
with self.test_session(use_gpu=True):
num_rows_ph = array_ops.placeholder(dtypes.int32)
num_columns_ph = array_ops.placeholder(dtypes.int32)
batch_shape_ph = array_ops.placeholder(dtypes.int32)
eye = linalg_ops.eye(num_rows_ph,
num_columns=num_columns_ph,
batch_shape=batch_shape_ph,
dtype=dtype)
eye_v = eye.eval(feed_dict={
num_rows_ph: num_rows,
num_columns_ph: num_columns,
batch_shape_ph: batch_shape
})
for i in range(batch_shape[0]):
for j in range(batch_shape[1]):
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
def test_non_batch_0x0(self):
num_rows = 0
dtype = np.int64
np_eye = np.eye(num_rows).astype(dtype)
with self.test_session(use_gpu=True):
eye = linalg_ops.eye(num_rows, dtype=dtype)
self.assertAllEqual((num_rows, num_rows), eye.get_shape())
self.assertAllEqual(np_eye, eye.eval())
def test_non_batch_2x0(self):
num_rows = 2
num_columns = 0
dtype = np.int64
np_eye = np.eye(num_rows, num_columns).astype(dtype)
with self.test_session(use_gpu=True):
eye = linalg_ops.eye(num_rows, num_columns=num_columns, dtype=dtype)
self.assertAllEqual((num_rows, num_columns), eye.get_shape())
self.assertAllEqual(np_eye, eye.eval())
def test_non_batch_0x2(self):
num_rows = 0
num_columns = 2
dtype = np.int64
np_eye = np.eye(num_rows, num_columns).astype(dtype)
with self.test_session(use_gpu=True):
eye = linalg_ops.eye(num_rows, num_columns=num_columns, dtype=dtype)
self.assertAllEqual((num_rows, num_columns), eye.get_shape())
self.assertAllEqual(np_eye, eye.eval())
def test_1x3_batch_0x0(self):
num_rows = 0
batch_shape = [1, 3]
dtype = np.float32
np_eye = np.eye(num_rows).astype(dtype)
with self.test_session(use_gpu=True):
eye = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype)
self.assertAllEqual((1, 3, 0, 0), eye.get_shape())
eye_v = eye.eval()
for i in range(batch_shape[0]):
for j in range(batch_shape[1]):
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
def test_1x3_batch_2x0(self):
num_rows = 2
num_columns = 0
batch_shape = [1, 3]
dtype = np.float32
np_eye = np.eye(num_rows, num_columns).astype(dtype)
with self.test_session(use_gpu=True):
eye = linalg_ops.eye(num_rows,
num_columns=num_columns,
batch_shape=batch_shape,
dtype=dtype)
self.assertAllEqual(batch_shape + [num_rows, num_columns],
eye.get_shape())
eye_v = eye.eval()
for i in range(batch_shape[0]):
for j in range(batch_shape[1]):
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
def test_1x3_batch_0x2(self):
num_rows = 0
num_columns = 2
batch_shape = [1, 3]
dtype = np.float32
np_eye = np.eye(num_rows, num_columns).astype(dtype)
with self.test_session(use_gpu=True):
eye = linalg_ops.eye(num_rows,
num_columns=num_columns,
batch_shape=batch_shape,
dtype=dtype)
self.assertAllEqual(batch_shape + [num_rows, num_columns],
eye.get_shape())
eye_v = eye.eval()
for i in range(batch_shape[0]):
for j in range(batch_shape[1]):
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
pass # Will be filled in below
if __name__ == '__main__':
def _GetEyeTest(num_rows, num_columns, batch_shape, dtype):
def Test(self):
eye_np = np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype)
if batch_shape is not None:
eye_np = np.tile(eye_np, batch_shape + [1, 1])
for use_placeholder in False, True:
if use_placeholder and (num_columns is None or batch_shape is None):
return
with self.test_session(use_gpu=True) as sess:
if use_placeholder:
num_rows_placeholder = array_ops.placeholder(
dtypes.int32, name="num_rows")
num_columns_placeholder = array_ops.placeholder(
dtypes.int32, name="num_columns")
batch_shape_placeholder = array_ops.placeholder(
dtypes.int32, name="batch_shape")
eye = linalg_ops.eye(
num_rows_placeholder,
num_columns=num_columns_placeholder,
batch_shape=batch_shape_placeholder,
dtype=dtype)
eye_tf = sess.run(
eye,
feed_dict={
num_rows_placeholder: num_rows,
num_columns_placeholder: num_columns,
batch_shape_placeholder: batch_shape
})
else:
eye_tf = linalg_ops.eye(
num_rows,
num_columns=num_columns,
batch_shape=batch_shape,
dtype=dtype).eval()
self.assertAllEqual(eye_np, eye_tf)
return Test
if __name__ == "__main__":
  # Register one generated EyeTest method for every combination of row
  # count, column count (None means "not passed"), batch shape and dtype,
  # then hand control to the test runner.
  _eye_dtypes = (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64,
                 dtypes.complex64, dtypes.complex128)
  for _num_rows in (0, 1, 2, 5):
    for _num_columns in (None, 0, 1, 2, 5):
      for _batch_shape in (None, [], [2], [2, 3]):
        for _dtype in _eye_dtypes:
          _case_name = (
              "dtype_%s_num_rows_%s_num_column_%s_batch_shape_%s_" %
              (_dtype.name, _num_rows, _num_columns, _batch_shape))
          _generated = _GetEyeTest(_num_rows, _num_columns, _batch_shape,
                                   _dtype)
          _AddTest(EyeTest, "EyeTest", _case_name, _generated)
  test.main()

View File

@ -25,11 +25,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_linalg_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import compat
# Names below are lower_case.
# pylint: disable=invalid-name
@ -105,8 +104,9 @@ def eye(num_rows,
in each batch matrix.
num_columns: Optional non-negative `int32` scalar `Tensor` giving the number
of columns in each batch matrix. Defaults to `num_rows`.
batch_shape: `int32` `Tensor`. If provided, returned `Tensor` will have
leading batch dimensions of this shape.
batch_shape: A list or tuple of Python integers or a 1-D `int32` `Tensor`.
If provided, the returned `Tensor` will have leading batch dimensions of
this shape.
dtype: The type of an element in the resulting `Tensor`
name: A name for this `Op`. Defaults to "eye".
@ -115,22 +115,32 @@ def eye(num_rows,
"""
with ops.name_scope(
name, default_name='eye', values=[num_rows, num_columns, batch_shape]):
is_square = num_columns is None
batch_shape = [] if batch_shape is None else batch_shape
batch_shape = ops.convert_to_tensor(
batch_shape, name='shape', dtype=dtypes.int32)
if num_columns is None:
diag_size = num_rows
else:
num_columns = num_rows if num_columns is None else num_columns
if isinstance(num_rows, ops.Tensor) or isinstance(
num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
batch_shape = ops.convert_to_tensor(
batch_shape, name='shape', dtype=dtypes.int32)
diag_size = math_ops.minimum(num_rows, num_columns)
diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
diag_ones = array_ops.ones(diag_shape, dtype=dtype)
diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
if not is_square:
shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
else:
if not isinstance(num_rows, compat.integral_types) or not isinstance(
num_columns, compat.integral_types):
raise TypeError(
'num_rows and num_columns must be positive integer values.')
batch_shape = [dim for dim in batch_shape]
is_square = num_rows == num_columns
diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
if not is_square:
shape = batch_shape + [num_rows, num_columns]
if num_columns is None:
diag_ones = array_ops.ones(diag_shape, dtype=dtype)
if is_square:
return array_ops.matrix_diag(diag_ones)
else:
shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
zero_matrix = array_ops.zeros(shape, dtype=dtype)
return array_ops.matrix_set_diag(zero_matrix, diag_ones)
@ -140,7 +150,7 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
form `M`-by-`N` matrices. Rhs is a tensor of shape `[..., M, K]` whose
inner-most 2 dimensions form `M`-by-`K` matrices. The computed output is a
inner-most 2 dimensions form `M`-by-`K` matrices. The computed output is a
`Tensor` of shape `[..., N, K]` whose inner-most 2 dimensions form `N`-by-`K`
matrices that solve the equations
`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]` in the least squares
@ -389,9 +399,9 @@ def norm(tensor, ord='euclidean', axis=None, keep_dims=False, name=None):
result = math_ops.reduce_max(result, max_axis, keep_dims=True)
else:
# General p-norms (positive p only)
result = math_ops.pow(math_ops.reduce_sum(
math_ops.pow(result, ord), axis, keep_dims=True),
1.0 / ord)
result = math_ops.pow(
math_ops.reduce_sum(
math_ops.pow(result, ord), axis, keep_dims=True), 1.0 / ord)
if not keep_dims:
result = array_ops.squeeze(result, axis)
return result