mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Make tf.eye accept Python integer shapes and avoid generating unnecessary shape handling ops.
Clean up test and add tests with placeholders. PiperOrigin-RevId: 165746090
This commit is contained in:
parent
109ecf823d
commit
378463ae89
|
|
@ -27,7 +27,14 @@ from tensorflow.python.ops import math_ops
|
|||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def _random_pd_matrix(n, rng):
|
||||
def _AddTest(test_class, op_name, testcase_name, fn):
|
||||
test_name = "_".join(["test", op_name, testcase_name])
|
||||
if hasattr(test_class, test_name):
|
||||
raise RuntimeError("Test %s defined more than once" % test_name)
|
||||
setattr(test_class, test_name, fn)
|
||||
|
||||
|
||||
def _RandomPDMatrix(n, rng):
|
||||
"""Random positive definite matrix."""
|
||||
temp = rng.randn(n, n)
|
||||
return temp.dot(temp.T)
|
||||
|
|
@ -44,8 +51,8 @@ class CholeskySolveTest(test.TestCase):
|
|||
for np_type, atol in [(np.float32, 0.05), (np.float64, 1e-5)]:
|
||||
# Create 2 x n x n matrix
|
||||
array = np.array(
|
||||
[_random_pd_matrix(n, self.rng), _random_pd_matrix(n, self.rng)
|
||||
]).astype(np_type)
|
||||
[_RandomPDMatrix(n, self.rng),
|
||||
_RandomPDMatrix(n, self.rng)]).astype(np_type)
|
||||
chol = linalg_ops.cholesky(array)
|
||||
for k in range(1, 3):
|
||||
rhs = self.rng.randn(2, n, k).astype(np_type)
|
||||
|
|
@ -55,174 +62,58 @@ class CholeskySolveTest(test.TestCase):
|
|||
|
||||
|
||||
class EyeTest(test.TestCase):
|
||||
|
||||
def test_non_batch_2x2(self):
|
||||
num_rows = 2
|
||||
dtype = np.float32
|
||||
np_eye = np.eye(num_rows).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
eye = linalg_ops.eye(num_rows, dtype=dtype)
|
||||
self.assertAllEqual((num_rows, num_rows), eye.get_shape())
|
||||
self.assertAllEqual(np_eye, eye.eval())
|
||||
|
||||
def test_non_batch_2x3(self):
|
||||
num_rows = 2
|
||||
num_columns = 3
|
||||
dtype = np.float32
|
||||
np_eye = np.eye(num_rows, num_columns).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
eye = linalg_ops.eye(num_rows, num_columns=num_columns, dtype=dtype)
|
||||
self.assertAllEqual((num_rows, num_columns), eye.get_shape())
|
||||
self.assertAllEqual(np_eye, eye.eval())
|
||||
|
||||
def test_1x3_batch_4x4(self):
|
||||
num_rows = 4
|
||||
batch_shape = [1, 3]
|
||||
dtype = np.float32
|
||||
np_eye = np.eye(num_rows).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
eye = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype)
|
||||
self.assertAllEqual(batch_shape + [num_rows, num_rows], eye.get_shape())
|
||||
eye_v = eye.eval()
|
||||
for i in range(batch_shape[0]):
|
||||
for j in range(batch_shape[1]):
|
||||
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
|
||||
|
||||
def test_1x3_batch_4x4_dynamic(self):
|
||||
num_rows = 4
|
||||
batch_shape = [1, 3]
|
||||
dtype = np.float32
|
||||
np_eye = np.eye(num_rows).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
num_rows_ph = array_ops.placeholder(dtypes.int32)
|
||||
batch_shape_ph = array_ops.placeholder(dtypes.int32)
|
||||
eye = linalg_ops.eye(num_rows_ph, batch_shape=batch_shape_ph, dtype=dtype)
|
||||
eye_v = eye.eval(
|
||||
feed_dict={num_rows_ph: num_rows,
|
||||
batch_shape_ph: batch_shape})
|
||||
for i in range(batch_shape[0]):
|
||||
for j in range(batch_shape[1]):
|
||||
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
|
||||
|
||||
def test_1x3_batch_5x4(self):
|
||||
num_rows = 5
|
||||
num_columns = 4
|
||||
batch_shape = [1, 3]
|
||||
dtype = np.float32
|
||||
np_eye = np.eye(num_rows, num_columns).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
eye = linalg_ops.eye(num_rows,
|
||||
num_columns=num_columns,
|
||||
batch_shape=batch_shape,
|
||||
dtype=dtype)
|
||||
self.assertAllEqual(batch_shape + [num_rows, num_columns],
|
||||
eye.get_shape())
|
||||
eye_v = eye.eval()
|
||||
for i in range(batch_shape[0]):
|
||||
for j in range(batch_shape[1]):
|
||||
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
|
||||
|
||||
def test_1x3_batch_5x4_dynamic(self):
|
||||
num_rows = 5
|
||||
num_columns = 4
|
||||
batch_shape = [1, 3]
|
||||
dtype = np.float32
|
||||
np_eye = np.eye(num_rows, num_columns).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
num_rows_ph = array_ops.placeholder(dtypes.int32)
|
||||
num_columns_ph = array_ops.placeholder(dtypes.int32)
|
||||
batch_shape_ph = array_ops.placeholder(dtypes.int32)
|
||||
eye = linalg_ops.eye(num_rows_ph,
|
||||
num_columns=num_columns_ph,
|
||||
batch_shape=batch_shape_ph,
|
||||
dtype=dtype)
|
||||
eye_v = eye.eval(feed_dict={
|
||||
num_rows_ph: num_rows,
|
||||
num_columns_ph: num_columns,
|
||||
batch_shape_ph: batch_shape
|
||||
})
|
||||
for i in range(batch_shape[0]):
|
||||
for j in range(batch_shape[1]):
|
||||
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
|
||||
|
||||
def test_non_batch_0x0(self):
|
||||
num_rows = 0
|
||||
dtype = np.int64
|
||||
np_eye = np.eye(num_rows).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
eye = linalg_ops.eye(num_rows, dtype=dtype)
|
||||
self.assertAllEqual((num_rows, num_rows), eye.get_shape())
|
||||
self.assertAllEqual(np_eye, eye.eval())
|
||||
|
||||
def test_non_batch_2x0(self):
|
||||
num_rows = 2
|
||||
num_columns = 0
|
||||
dtype = np.int64
|
||||
np_eye = np.eye(num_rows, num_columns).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
eye = linalg_ops.eye(num_rows, num_columns=num_columns, dtype=dtype)
|
||||
self.assertAllEqual((num_rows, num_columns), eye.get_shape())
|
||||
self.assertAllEqual(np_eye, eye.eval())
|
||||
|
||||
def test_non_batch_0x2(self):
|
||||
num_rows = 0
|
||||
num_columns = 2
|
||||
dtype = np.int64
|
||||
np_eye = np.eye(num_rows, num_columns).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
eye = linalg_ops.eye(num_rows, num_columns=num_columns, dtype=dtype)
|
||||
self.assertAllEqual((num_rows, num_columns), eye.get_shape())
|
||||
self.assertAllEqual(np_eye, eye.eval())
|
||||
|
||||
def test_1x3_batch_0x0(self):
|
||||
num_rows = 0
|
||||
batch_shape = [1, 3]
|
||||
dtype = np.float32
|
||||
np_eye = np.eye(num_rows).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
eye = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype)
|
||||
self.assertAllEqual((1, 3, 0, 0), eye.get_shape())
|
||||
eye_v = eye.eval()
|
||||
for i in range(batch_shape[0]):
|
||||
for j in range(batch_shape[1]):
|
||||
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
|
||||
|
||||
def test_1x3_batch_2x0(self):
|
||||
num_rows = 2
|
||||
num_columns = 0
|
||||
batch_shape = [1, 3]
|
||||
dtype = np.float32
|
||||
np_eye = np.eye(num_rows, num_columns).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
eye = linalg_ops.eye(num_rows,
|
||||
num_columns=num_columns,
|
||||
batch_shape=batch_shape,
|
||||
dtype=dtype)
|
||||
self.assertAllEqual(batch_shape + [num_rows, num_columns],
|
||||
eye.get_shape())
|
||||
eye_v = eye.eval()
|
||||
for i in range(batch_shape[0]):
|
||||
for j in range(batch_shape[1]):
|
||||
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
|
||||
|
||||
def test_1x3_batch_0x2(self):
|
||||
num_rows = 0
|
||||
num_columns = 2
|
||||
batch_shape = [1, 3]
|
||||
dtype = np.float32
|
||||
np_eye = np.eye(num_rows, num_columns).astype(dtype)
|
||||
with self.test_session(use_gpu=True):
|
||||
eye = linalg_ops.eye(num_rows,
|
||||
num_columns=num_columns,
|
||||
batch_shape=batch_shape,
|
||||
dtype=dtype)
|
||||
self.assertAllEqual(batch_shape + [num_rows, num_columns],
|
||||
eye.get_shape())
|
||||
eye_v = eye.eval()
|
||||
for i in range(batch_shape[0]):
|
||||
for j in range(batch_shape[1]):
|
||||
self.assertAllEqual(np_eye, eye_v[i, j, :, :])
|
||||
pass # Will be filled in below
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def _GetEyeTest(num_rows, num_columns, batch_shape, dtype):
|
||||
|
||||
def Test(self):
|
||||
eye_np = np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype)
|
||||
if batch_shape is not None:
|
||||
eye_np = np.tile(eye_np, batch_shape + [1, 1])
|
||||
for use_placeholder in False, True:
|
||||
if use_placeholder and (num_columns is None or batch_shape is None):
|
||||
return
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
if use_placeholder:
|
||||
num_rows_placeholder = array_ops.placeholder(
|
||||
dtypes.int32, name="num_rows")
|
||||
num_columns_placeholder = array_ops.placeholder(
|
||||
dtypes.int32, name="num_columns")
|
||||
batch_shape_placeholder = array_ops.placeholder(
|
||||
dtypes.int32, name="batch_shape")
|
||||
eye = linalg_ops.eye(
|
||||
num_rows_placeholder,
|
||||
num_columns=num_columns_placeholder,
|
||||
batch_shape=batch_shape_placeholder,
|
||||
dtype=dtype)
|
||||
eye_tf = sess.run(
|
||||
eye,
|
||||
feed_dict={
|
||||
num_rows_placeholder: num_rows,
|
||||
num_columns_placeholder: num_columns,
|
||||
batch_shape_placeholder: batch_shape
|
||||
})
|
||||
else:
|
||||
eye_tf = linalg_ops.eye(
|
||||
num_rows,
|
||||
num_columns=num_columns,
|
||||
batch_shape=batch_shape,
|
||||
dtype=dtype).eval()
|
||||
self.assertAllEqual(eye_np, eye_tf)
|
||||
|
||||
return Test
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for _num_rows in 0, 1, 2, 5:
|
||||
for _num_columns in None, 0, 1, 2, 5:
|
||||
for _batch_shape in None, [], [2], [2, 3]:
|
||||
for _dtype in (dtypes.int32, dtypes.int64, dtypes.float32,
|
||||
dtypes.float64, dtypes.complex64, dtypes.complex128):
|
||||
name = "dtype_%s_num_rows_%s_num_column_%s_batch_shape_%s_" % (
|
||||
_dtype.name, _num_rows, _num_columns, _batch_shape)
|
||||
_AddTest(EyeTest, "EyeTest", name,
|
||||
_GetEyeTest(_num_rows, _num_columns, _batch_shape, _dtype))
|
||||
|
||||
test.main()
|
||||
|
|
|
|||
|
|
@ -25,11 +25,10 @@ from tensorflow.python.framework import ops
|
|||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.gen_linalg_ops import *
|
||||
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
# Names below are lower_case.
|
||||
# pylint: disable=invalid-name
|
||||
|
|
@ -105,8 +104,9 @@ def eye(num_rows,
|
|||
in each batch matrix.
|
||||
num_columns: Optional non-negative `int32` scalar `Tensor` giving the number
|
||||
of columns in each batch matrix. Defaults to `num_rows`.
|
||||
batch_shape: `int32` `Tensor`. If provided, returned `Tensor` will have
|
||||
leading batch dimensions of this shape.
|
||||
batch_shape: A list or tuple of Python integers or a 1-D `int32` `Tensor`.
|
||||
If provided, the returned `Tensor` will have leading batch dimensions of
|
||||
this shape.
|
||||
dtype: The type of an element in the resulting `Tensor`
|
||||
name: A name for this `Op`. Defaults to "eye".
|
||||
|
||||
|
|
@ -115,22 +115,32 @@ def eye(num_rows,
|
|||
"""
|
||||
with ops.name_scope(
|
||||
name, default_name='eye', values=[num_rows, num_columns, batch_shape]):
|
||||
|
||||
is_square = num_columns is None
|
||||
batch_shape = [] if batch_shape is None else batch_shape
|
||||
batch_shape = ops.convert_to_tensor(
|
||||
batch_shape, name='shape', dtype=dtypes.int32)
|
||||
|
||||
if num_columns is None:
|
||||
diag_size = num_rows
|
||||
else:
|
||||
num_columns = num_rows if num_columns is None else num_columns
|
||||
if isinstance(num_rows, ops.Tensor) or isinstance(
|
||||
num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
|
||||
batch_shape = ops.convert_to_tensor(
|
||||
batch_shape, name='shape', dtype=dtypes.int32)
|
||||
diag_size = math_ops.minimum(num_rows, num_columns)
|
||||
diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
|
||||
diag_ones = array_ops.ones(diag_shape, dtype=dtype)
|
||||
diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
|
||||
if not is_square:
|
||||
shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
|
||||
else:
|
||||
if not isinstance(num_rows, compat.integral_types) or not isinstance(
|
||||
num_columns, compat.integral_types):
|
||||
raise TypeError(
|
||||
'num_rows and num_columns must be positive integer values.')
|
||||
batch_shape = [dim for dim in batch_shape]
|
||||
is_square = num_rows == num_columns
|
||||
diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
|
||||
if not is_square:
|
||||
shape = batch_shape + [num_rows, num_columns]
|
||||
|
||||
if num_columns is None:
|
||||
diag_ones = array_ops.ones(diag_shape, dtype=dtype)
|
||||
if is_square:
|
||||
return array_ops.matrix_diag(diag_ones)
|
||||
else:
|
||||
shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
|
||||
zero_matrix = array_ops.zeros(shape, dtype=dtype)
|
||||
return array_ops.matrix_set_diag(zero_matrix, diag_ones)
|
||||
|
||||
|
|
@ -140,7 +150,7 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
|
|||
|
||||
`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
|
||||
form `M`-by-`N` matrices. Rhs is a tensor of shape `[..., M, K]` whose
|
||||
inner-most 2 dimensions form `M`-by-`K` matrices. The computed output is a
|
||||
inner-most 2 dimensions form `M`-by-`K` matrices. The computed output is a
|
||||
`Tensor` of shape `[..., N, K]` whose inner-most 2 dimensions form `M`-by-`K`
|
||||
matrices that solve the equations
|
||||
`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]` in the least squares
|
||||
|
|
@ -389,9 +399,9 @@ def norm(tensor, ord='euclidean', axis=None, keep_dims=False, name=None):
|
|||
result = math_ops.reduce_max(result, max_axis, keep_dims=True)
|
||||
else:
|
||||
# General p-norms (positive p only)
|
||||
result = math_ops.pow(math_ops.reduce_sum(
|
||||
math_ops.pow(result, ord), axis, keep_dims=True),
|
||||
1.0 / ord)
|
||||
result = math_ops.pow(
|
||||
math_ops.reduce_sum(
|
||||
math_ops.pow(result, ord), axis, keep_dims=True), 1.0 / ord)
|
||||
if not keep_dims:
|
||||
result = array_ops.squeeze(result, axis)
|
||||
return result
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user