mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Fix numpy 1.13 incompatibilities (#10501)
* Fix numpy 1.13 incompatibilities * Skip tests with numpy 1.13.0
This commit is contained in:
parent
4572c41df0
commit
7c46214abb
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
|
@ -587,6 +588,7 @@ class SparseReduceSumTest(test_util.TensorFlowTestCase):
|
|||
self._compare(sp_t, reduction_axes, ndims, False)
|
||||
self._compare(sp_t, reduction_axes, ndims, True)
|
||||
|
||||
@unittest.skipIf(np.__version__ == "1.13.0", "numpy 1.13 bug")
|
||||
def testSimpleAndRandomInputs(self):
|
||||
sp_t = sparse_tensor.SparseTensor(self.ind, self.vals, self.dense_shape)
|
||||
|
||||
|
|
@ -619,6 +621,7 @@ class SparseReduceSumTest(test_util.TensorFlowTestCase):
|
|||
with self.assertRaisesOpError("Invalid reduction dimension 2"):
|
||||
sparse_ops.sparse_reduce_sum(sp_t, 2).eval()
|
||||
|
||||
@unittest.skipIf(np.__version__ == "1.13.0", "numpy 1.13 bug")
|
||||
def testGradient(self):
|
||||
np.random.seed(8161)
|
||||
test_dims = [(11, 1, 5, 7, 1), (2, 2)]
|
||||
|
|
|
|||
|
|
@ -726,7 +726,7 @@ def _assert_ranks_condition(
|
|||
|
||||
# Attempt to statically defined rank.
|
||||
ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks])
|
||||
if None not in ranks_static:
|
||||
if not any(r is None for r in ranks_static):
|
||||
for rank_static in ranks_static:
|
||||
if rank_static.ndim != 0:
|
||||
raise ValueError('Rank must be a scalar.')
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
|
|||
Raises:
|
||||
ValueError: If `params` is empty.
|
||||
"""
|
||||
if params in (None, (), []):
|
||||
if params is None or params in ((), []):
|
||||
raise ValueError("Need at least one param")
|
||||
if isinstance(params, variables.PartitionedVariable):
|
||||
params = list(params) # Iterate to get the underlying Variables.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user