mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Fix flakiness in GpuMultiSessionMemoryTest.
PiperOrigin-RevId: 157781368
This commit is contained in:
parent
f7de292df3
commit
fb4bc806a8
|
|
@ -228,9 +228,9 @@ class BroadcastSimpleTest(test.TestCase):
|
||||||
class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
|
class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
|
||||||
"""Tests concurrent sessions executing on the same GPU."""
|
"""Tests concurrent sessions executing on the same GPU."""
|
||||||
|
|
||||||
def _run_session(self, results):
|
def _run_session(self, session, results):
|
||||||
n_iterations = 500
|
n_iterations = 500
|
||||||
with self.test_session(use_gpu=True) as s:
|
with session as s:
|
||||||
data = variables.Variable(1.0)
|
data = variables.Variable(1.0)
|
||||||
with ops.device('/gpu:0'):
|
with ops.device('/gpu:0'):
|
||||||
random_seed.set_random_seed(1)
|
random_seed.set_random_seed(1)
|
||||||
|
|
@ -245,29 +245,29 @@ class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
for _ in xrange(n_iterations):
|
for _ in xrange(n_iterations):
|
||||||
value = s.run(x4)
|
value = s.run(x4)
|
||||||
results.append(value)
|
results.add(value.flat[0])
|
||||||
if value != results[0]:
|
if len(results) != 1:
|
||||||
break
|
break
|
||||||
|
|
||||||
def testConcurrentSessions(self):
|
def testConcurrentSessions(self):
|
||||||
if not test.is_gpu_available():
|
|
||||||
return
|
|
||||||
|
|
||||||
n_threads = 4
|
n_threads = 4
|
||||||
results = [[]] * n_threads
|
threads = []
|
||||||
threads = [
|
results = []
|
||||||
threading.Thread(target=self._run_session, args=(results[i],))
|
for _ in xrange(n_threads):
|
||||||
for i in xrange(n_threads)
|
session = self.test_session(graph=ops.Graph(), use_gpu=True)
|
||||||
]
|
results.append(set())
|
||||||
|
args = (session, results[-1])
|
||||||
|
threads.append(threading.Thread(target=self._run_session, args=args))
|
||||||
|
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
thread.start()
|
thread.start()
|
||||||
for thread in threads:
|
for thread in threads:
|
||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
flat_results = [x for x in itertools.chain(*results)]
|
flat_results = set([x for x in itertools.chain(*results)])
|
||||||
self.assertNotEqual(0, len(flat_results))
|
self.assertEqual(1,
|
||||||
for result in flat_results:
|
len(flat_results),
|
||||||
self.assertEqual(result, flat_results[0])
|
'Expected single value, got %r' % flat_results)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user