Make functional_ops compatible with eager execution by ignoring
caching devices when in eager mode

PiperOrigin-RevId: 173737949
This commit is contained in:
parent d1c59bd375
commit 245a5c171a
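The diff below applies the same guard to foldl, foldr, map_fn, and scan: the caching-device bookkeeping on the variable scope is set up and torn down only while a graph is being built, since caching devices are not yet supported under eager execution. A minimal sketch of that pattern, using the TF 1.4-era context.in_graph_mode() API; my_functional_op and its trivial body are illustrative stand-ins, not the real implementation:

    from tensorflow.python.eager import context
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import variable_scope as vs


    def my_functional_op(fn, elems, name=None):
      # Capture the mode once so the setup and the cleanup below agree.
      in_graph_mode = context.in_graph_mode()
      with ops.name_scope(name, "my_functional_op", [elems]):
        varscope_caching_device_was_none = False
        if in_graph_mode:
          # Caching get_variable results on a device only makes sense while
          # building a graph; skip the bookkeeping entirely in eager mode.
          varscope = vs.get_variable_scope()
          if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)
            varscope_caching_device_was_none = True

        # Stand-in for the real while_loop-based computation (elided here).
        result = fn(ops.convert_to_tensor(elems, name="elems"))

        if in_graph_mode and varscope_caching_device_was_none:
          varscope.set_caching_device(None)
        return result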
tensorflow/python/BUILD

@@ -897,6 +897,7 @@ py_library(
         ":tensor_shape",
         ":util",
         ":variable_scope",
+        "//tensorflow/python/eager:context",
     ],
 )
 
tensorflow/python/kernel_tests/functional_ops_test.py

@@ -52,6 +52,7 @@ def simple_scoped_fn(a, x):
 
 class FunctionalOpsTest(test.TestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testFoldl_Simple(self):
     with self.test_session():
       elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -59,13 +60,13 @@ class FunctionalOpsTest(test.TestCase):
       r = functional_ops.foldl(
           lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
           elems)
-      self.assertAllEqual(208, r.eval())
+      self.assertAllEqual(208, self.evaluate(r))
 
       r = functional_ops.foldl(
           lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
           elems,
           initializer=10)
-      self.assertAllEqual(880, r.eval())
+      self.assertAllEqual(880, self.evaluate(r))
 
   def testFoldl_Scoped(self):
     with self.test_session() as sess:
@@ -78,14 +79,15 @@ class FunctionalOpsTest(test.TestCase):
         self.assertEqual(variables.trainable_variables()[0].name,
                          "root/body/two:0")
         sess.run([variables.global_variables_initializer()])
-        self.assertAllEqual(208, r.eval())
+        self.assertAllEqual(208, self.evaluate(r))
 
         # Now let's reuse our single variable.
         varscope.reuse_variables()
         r = functional_ops.foldl(simple_scoped_fn, elems, initializer=10)
         self.assertEqual(len(variables.trainable_variables()), 1)
-        self.assertAllEqual(880, r.eval())
+        self.assertAllEqual(880, self.evaluate(r))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testFoldr_Simple(self):
     with self.test_session():
       elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -93,13 +95,13 @@ class FunctionalOpsTest(test.TestCase):
       r = functional_ops.foldr(
           lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
           elems)
-      self.assertAllEqual(450, r.eval())
+      self.assertAllEqual(450, self.evaluate(r))
 
       r = functional_ops.foldr(
           lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
           elems,
           initializer=10)
-      self.assertAllEqual(1282, r.eval())
+      self.assertAllEqual(1282, self.evaluate(r))
 
   def testFoldr_Scoped(self):
     with self.test_session() as sess:
@@ -112,13 +114,13 @@ class FunctionalOpsTest(test.TestCase):
         self.assertEqual(variables.trainable_variables()[0].name,
                          "root/body/two:0")
         sess.run([variables.global_variables_initializer()])
-        self.assertAllEqual(450, r.eval())
+        self.assertAllEqual(450, self.evaluate(r))
 
         # Now let's reuse our single variable.
         varscope.reuse_variables()
         r = functional_ops.foldr(simple_scoped_fn, elems, initializer=10)
         self.assertEqual(len(variables.trainable_variables()), 1)
-        self.assertAllEqual(1282, r.eval())
+        self.assertAllEqual(1282, self.evaluate(r))
 
   # pylint: disable=unnecessary-lambda
   def testFold_Grad(self):
@@ -128,21 +130,23 @@ class FunctionalOpsTest(test.TestCase):
       r = functional_ops.foldl(
           lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
       r = gradients_impl.gradients(r, v)[0]
-      self.assertAllEqual(720.0, r.eval())
+      self.assertAllEqual(720.0, self.evaluate(r))
 
       r = functional_ops.foldr(
           lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
       r = gradients_impl.gradients(r, v)[0]
-      self.assertAllEqual(720.0, r.eval())
+      self.assertAllEqual(720.0, self.evaluate(r))
   # pylint: enable=unnecessary-lambda
 
+  @test_util.run_in_graph_and_eager_modes()
   def testMap_Simple(self):
     with self.test_session():
       nums = [1, 2, 3, 4, 5, 6]
       elems = constant_op.constant(nums, name="data")
       r = functional_ops.map_fn(
           lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems)
-      self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())
+      self.assertAllEqual(
+          np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
 
   def testMapSparseTensor(self):
     with self.test_session():
@@ -177,13 +181,13 @@ class FunctionalOpsTest(test.TestCase):
         self.assertEqual(variables.trainable_variables()[0].name,
                          "root/body/two:0")
         sess.run([variables.global_variables_initializer()])
-        self.assertAllEqual(doubles, r.eval())
+        self.assertAllEqual(doubles, self.evaluate(r))
 
         # Now let's reuse our single variable.
         varscope.reuse_variables()
         r = functional_ops.map_fn(double_scoped, elems)
         self.assertEqual(len(variables.trainable_variables()), 1)
-        self.assertAllEqual(doubles, r.eval())
+        self.assertAllEqual(doubles, self.evaluate(r))
 
   def testMap_Grad(self):
     with self.test_session():
@@ -192,19 +196,22 @@ class FunctionalOpsTest(test.TestCase):
       y = functional_ops.map_fn(
           lambda x: math_ops.multiply(math_ops.square(x), param), elems)
       r = gradients_impl.gradients(y, param)[0]
-      self.assertAllEqual(91.0, r.eval())
+      self.assertAllEqual(91.0, self.evaluate(r))
       r = gradients_impl.gradients(y, elems)[0]
-      self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], r.eval())
+      self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testMap_SimpleNotTensor(self):
     with self.test_session():
       nums = np.array([1, 2, 3, 4, 5, 6])
       r = functional_ops.map_fn(
           lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums)
-      self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())
+      self.assertAllEqual(
+          np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testMap_SingleInputMultiOutput(self):
-    with self.test_session() as sess:
+    with self.test_session():
       nums = np.array([1, 2, 3, 4, 5, 6])
       r = functional_ops.map_fn(
           lambda x: ((x + 3) * 2, -(x + 3) * 2),
@@ -213,10 +220,11 @@ class FunctionalOpsTest(test.TestCase):
       self.assertEqual(2, len(r))
       self.assertEqual((6,), r[0].get_shape())
       self.assertEqual((6,), r[1].get_shape())
-      received = sess.run(r)
+      received = self.evaluate(r)
       self.assertAllEqual((nums + 3) * 2, received[0])
       self.assertAllEqual(-(nums + 3) * 2, received[1])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testMap_MultiOutputMismatchedDtype(self):
     with self.test_session():
       nums = np.array([1, 2, 3, 4, 5, 6])
@@ -228,6 +236,7 @@ class FunctionalOpsTest(test.TestCase):
             nums,
             dtype=[dtypes.int64, dtypes.int64])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testMap_MultiInputSingleOutput(self):
     with self.test_session():
       nums = np.array([1, 2, 3, 4, 5, 6])
@@ -235,11 +244,12 @@ class FunctionalOpsTest(test.TestCase):
           lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)),
           dtype=dtypes.int64)
       self.assertEqual((6,), r.get_shape())
-      received = r.eval()
+      received = self.evaluate(r)
       self.assertAllEqual(nums * nums + (-nums), received)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testMap_MultiInputSameStructureOutput(self):
-    with self.test_session() as sess:
+    with self.test_session():
       nums = np.array([1, 2, 3, 4, 5, 6])
       r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])),
                                 (nums, (2 * nums, -nums)))
@@ -247,11 +257,12 @@ class FunctionalOpsTest(test.TestCase):
       self.assertEqual((6,), r[0].get_shape())
       self.assertEqual((6,), r[1].get_shape())
       self.assertEqual((6,), r[2].get_shape())
-      received = sess.run(r)
+      received = self.evaluate(r)
       self.assertAllEqual(2 * nums, received[0])
       self.assertAllEqual(-nums, received[1])
       self.assertAllEqual(nums, received[2])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testScan_Simple(self):
     with self.test_session():
       elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
@@ -259,24 +270,26 @@ class FunctionalOpsTest(test.TestCase):
 
       # pylint: disable=unnecessary-lambda
       r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
-      self.assertAllEqual([1., 2., 6., 24., 120., 720.], r.eval())
+      self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))
 
       r = functional_ops.scan(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
-      self.assertAllEqual([2., 4., 12., 48., 240., 1440.], r.eval())
+      self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
       # pylint: enable=unnecessary-lambda
 
+  @test_util.run_in_graph_and_eager_modes()
   def testScan_SingleInputMultiOutput(self):
-    with self.test_session() as sess:
+    with self.test_session():
       elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
       initializer = (np.array(1.0), np.array(-1.0))
       r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
                               initializer)
-      r_value = sess.run(r)
+      r_value = self.evaluate(r)
 
       self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
       self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testScan_MultiInputSingleOutput(self):
     with self.test_session():
       elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
@@ -284,17 +297,19 @@ class FunctionalOpsTest(test.TestCase):
       # Multiply a * 1 each time
       r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
                               (elems + 1, -elems), initializer)
-      self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], r.eval())
+      self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testScan_MultiInputSameTypeOutput(self):
-    with self.test_session() as sess:
+    with self.test_session():
       elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
       r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
                               (elems, -elems))
-      r_value = sess.run(r)
+      r_value = self.evaluate(r)
       self.assertAllEqual(np.cumsum(elems), r_value[0])
       self.assertAllEqual(np.cumsum(-elems), r_value[1])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testScan_MultiOutputMismatchedInitializer(self):
     with self.test_session():
       elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
@@ -316,15 +331,16 @@ class FunctionalOpsTest(test.TestCase):
                          "root/body/two:0")
         sess.run([variables.global_variables_initializer()])
         results = np.array([1, 6, 18, 44, 98, 208])
-        self.assertAllEqual(results, r.eval())
+        self.assertAllEqual(results, self.evaluate(r))
 
         # Now let's reuse our single variable.
         varscope.reuse_variables()
         r = functional_ops.scan(simple_scoped_fn, elems, initializer=2)
         self.assertEqual(len(variables.trainable_variables()), 1)
         results = np.array([6, 16, 38, 84, 178, 368])
-        self.assertAllEqual(results, r.eval())
+        self.assertAllEqual(results, self.evaluate(r))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testScanFoldl_Nested(self):
     with self.test_session():
       elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
@@ -346,7 +362,7 @@ class FunctionalOpsTest(test.TestCase):
       # t == 3, a == 2.25, x == 4 (returns 9)
       # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
       # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9
-      self.assertAllClose([1., 1., 2.25, 9.], r.eval())
+      self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))
 
   def testScan_Control(self):
     with self.test_session() as sess:
@@ -369,7 +385,7 @@ class FunctionalOpsTest(test.TestCase):
           lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
       # pylint: enable=unnecessary-lambda
       r = gradients_impl.gradients(r, v)[0]
-      self.assertAllEqual(873.0, r.eval())
+      self.assertAllEqual(873.0, self.evaluate(r))
 
   def testScanGradientWithPartStopGradient(self):
     a = variables.Variable(0.0, name="a")
@@ -383,6 +399,7 @@ class FunctionalOpsTest(test.TestCase):
       variables.global_variables_initializer().run()
       sess.run(grad)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testFoldShape(self):
     with self.test_session():
       x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
@@ -392,32 +409,37 @@ class FunctionalOpsTest(test.TestCase):
 
       initializer = constant_op.constant([0, 0, 0])
       y = functional_ops.foldl(fn, x, initializer=initializer)
-      self.assertAllEqual(y.get_shape(), y.eval().shape)
+      self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testMapShape(self):
     with self.test_session():
       x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
       y = functional_ops.map_fn(lambda e: e, x)
-      self.assertAllEqual(y.get_shape(), y.eval().shape)
+      self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
 
   def testMapUnknownShape(self):
     x = array_ops.placeholder(dtypes.float32)
     y = functional_ops.map_fn(lambda e: e, x)
     self.assertIs(None, y.get_shape().dims)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testMapEmptyScalar(self):
     with self.test_session():
       map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([]))
       self.assertAllEqual([0], map_return.get_shape().dims)
-      self.assertAllEqual([0], map_return.eval().shape)
+      self.assertAllEqual([0], self.evaluate(map_return).shape)
 
+  # TODO(akshayka): this test fails in eager: the iterable is of length 0 so
+  # so the body of the while loop never executes
   def testMapEmptyTensor(self):
     with self.test_session():
       map_return = functional_ops.map_fn(lambda x: array_ops.zeros([3, 2]),
                                          constant_op.constant([]))
       self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
-      self.assertAllEqual([0, 3, 2], map_return.eval().shape)
+      self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testScanShape(self):
     with self.test_session():
       x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
@@ -427,14 +449,16 @@ class FunctionalOpsTest(test.TestCase):
 
       initializer = constant_op.constant([0, 0, 0])
       y = functional_ops.scan(fn, x, initializer=initializer)
-      self.assertAllEqual(y.get_shape(), y.eval().shape)
+      self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
 
+  # TODO(akshayka): this test fails in eager: the iterable is of length 0 so
+  # so the body of the while loop never executes
   def testScanEmptyTensor(self):
     with self.test_session():
       x = functional_ops.scan(
          lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4]))
       self.assertAllEqual([0, 2, 4], x.get_shape())
-      self.assertAllEqual(x.get_shape(), x.eval().shape)
+      self.assertAllEqual(x.get_shape(), self.evaluate(x).shape)
 
   def testScanUnknownShape(self):
     x = array_ops.placeholder(dtypes.float32)
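All of the test changes above follow one recipe: drop the explicit sess handle, fetch results through self.evaluate(...) (which runs the tensor in a session in graph mode and simply reads the value in eager mode), and opt the test into both modes with the run_in_graph_and_eager_modes decorator. A hypothetical standalone test in that style (the class, test name, and values are illustrative, not taken from the diff):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import test_util
    from tensorflow.python.ops import functional_ops
    from tensorflow.python.ops import math_ops
    from tensorflow.python.platform import test


    class MyFunctionalOpsTest(test.TestCase):

      @test_util.run_in_graph_and_eager_modes()
      def testFoldlSum(self):
        with self.test_session():
          elems = constant_op.constant([1, 2, 3, 4], name="data")
          r = functional_ops.foldl(lambda a, x: math_ops.add(a, x), elems)
          # self.evaluate works whether r is a graph tensor or an eager value.
          self.assertAllEqual(10, self.evaluate(r))


    if __name__ == "__main__":
      test.main()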
tensorflow/python/ops/functional_ops.py

@@ -27,6 +27,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
@@ -87,15 +88,20 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
   if not callable(fn):
     raise TypeError("fn must be callable.")
 
+  in_graph_mode = context.in_graph_mode()
   with ops.name_scope(name, "foldl", [elems]):
-    # Any get_variable calls in fn will cache the first call locally
-    # and not issue repeated network I/O requests for each iteration.
-    varscope = vs.get_variable_scope()
-    varscope_caching_device_was_none = False
-    if varscope.caching_device is None:
-      # TODO(ebrevdo): Change to using colocate_with here and in other methods.
-      varscope.set_caching_device(lambda op: op.device)
-      varscope_caching_device_was_none = True
+    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
+    # supported in Eager
+    if in_graph_mode:
+      # Any get_variable calls in fn will cache the first call locally
+      # and not issue repeated network I/O requests for each iteration.
+      varscope = vs.get_variable_scope()
+      varscope_caching_device_was_none = False
+      if varscope.caching_device is None:
+        # TODO(ebrevdo): Change to using colocate_with here and in other
+        # methods.
+        varscope.set_caching_device(lambda op: op.device)
+        varscope_caching_device_was_none = True
 
     # Convert elems to tensor array.
     elems = ops.convert_to_tensor(elems, name="elems")
@@ -121,7 +127,9 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
         back_prop=back_prop,
         swap_memory=swap_memory)
 
-    if varscope_caching_device_was_none:
+    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
+    # supported in Eager
+    if in_graph_mode and varscope_caching_device_was_none:
       varscope.set_caching_device(None)
     return r_a
 
@@ -167,15 +175,20 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
   if not callable(fn):
     raise TypeError("fn must be callable.")
 
+  in_graph_mode = context.in_graph_mode()
   with ops.name_scope(name, "foldr", [elems]):
-    # Any get_variable calls in fn will cache the first call locally
-    # and not issue repeated network I/O requests for each iteration.
-    varscope = vs.get_variable_scope()
-    varscope_caching_device_was_none = False
-    if varscope.caching_device is None:
-      # TODO(ebrevdo): Change to using colocate_with here and in other methods.
-      varscope.set_caching_device(lambda op: op.device)
-      varscope_caching_device_was_none = True
+    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
+    # supported in Eager
+    if in_graph_mode:
+      # Any get_variable calls in fn will cache the first call locally and not
+      # issue repeated network I/O requests for each iteration.
+      varscope = vs.get_variable_scope()
+      varscope_caching_device_was_none = False
+      if varscope.caching_device is None:
+        # TODO(ebrevdo): Change to using colocate_with here and in other
+        # methods.
+        varscope.set_caching_device(lambda op: op.device)
+        varscope_caching_device_was_none = True
 
     # Convert elems to tensor array.
     elems = ops.convert_to_tensor(elems, name="elems")
@@ -201,7 +214,9 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
         back_prop=back_prop,
         swap_memory=swap_memory)
 
-    if varscope_caching_device_was_none:
+    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
+    # supported in Eager
+    if in_graph_mode and varscope_caching_device_was_none:
       varscope.set_caching_device(None)
     return r_a
 
@@ -324,15 +339,20 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
 
   elems_flat = input_flatten(elems)
 
+  in_graph_mode = context.in_graph_mode()
   with ops.name_scope(name, "map", elems_flat):
-    # Any get_variable calls in fn will cache the first call locally
-    # and not issue repeated network I/O requests for each iteration.
-    varscope = vs.get_variable_scope()
-    varscope_caching_device_was_none = False
-    if varscope.caching_device is None:
-      # TODO(ebrevdo): Change to using colocate_with here and in other methods.
-      varscope.set_caching_device(lambda op: op.device)
-      varscope_caching_device_was_none = True
+    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
+    # supported in Eager
+    if in_graph_mode:
+      # Any get_variable calls in fn will cache the first call locally
+      # and not issue repeated network I/O requests for each iteration.
+      varscope = vs.get_variable_scope()
+      varscope_caching_device_was_none = False
+      if varscope.caching_device is None:
+        # TODO(ebrevdo): Change to using colocate_with here and in other
+        # methods.
+        varscope.set_caching_device(lambda op: op.device)
+        varscope_caching_device_was_none = True
 
     elems_flat = [
         ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]
@@ -396,7 +416,9 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
       r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
           r.get_shape()[1:]))
 
-    if varscope_caching_device_was_none:
+    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
+    # supported in Eager
+    if in_graph_mode and varscope_caching_device_was_none:
       varscope.set_caching_device(None)
 
     return output_pack(results_flat)
@@ -509,15 +531,20 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
 
   elems_flat = input_flatten(elems)
 
+  in_graph_mode = context.in_graph_mode()
   with ops.name_scope(name, "scan", elems_flat):
-    # Any get_variable calls in fn will cache the first call locally
-    # and not issue repeated network I/O requests for each iteration.
-    varscope = vs.get_variable_scope()
-    varscope_caching_device_was_none = False
-    if varscope.caching_device is None:
-      # TODO(ebrevdo): Change to using colocate_with here and in other methods.
-      varscope.set_caching_device(lambda op: op.device)
-      varscope_caching_device_was_none = True
+    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
+    # supported in Eager
+    if in_graph_mode:
+      # Any get_variable calls in fn will cache the first call locally
+      # and not issue repeated network I/O requests for each iteration.
+      varscope = vs.get_variable_scope()
+      varscope_caching_device_was_none = False
+      if varscope.caching_device is None:
+        # TODO(ebrevdo): Change to using colocate_with here and in other
+        # methods.
+        varscope.set_caching_device(lambda op: op.device)
+        varscope_caching_device_was_none = True
 
     # Convert elems to tensor array.
     elems_flat = [
@@ -594,7 +621,9 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
       r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
          r.get_shape()[1:]))
 
-    if varscope_caching_device_was_none:
+    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
+    # supported in Eager
+    if in_graph_mode and varscope_caching_device_was_none:
       varscope.set_caching_device(None)
 
     return output_pack(results_flat)
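With the caching-device setup skipped outside graph mode, the functional ops can be called directly under eager execution, with no session or .eval() involved, which is what the updated tests exercise. A small usage sketch, assuming the internal context.eager_mode() helper that the test utilities of this era rely on (illustrative, not part of the commit):

    from tensorflow.python.eager import context
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import functional_ops
    from tensorflow.python.ops import math_ops

    # Run a fold eagerly: no Session and no .eval(); the result is concrete.
    with context.eager_mode():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0])
      product = functional_ops.foldl(lambda a, x: math_ops.multiply(a, x), elems)
      print(product)  # An eager tensor holding 24.0 (= 1 * 2 * 3 * 4).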