Fix flakiness in GpuMultiSessionMemoryTest.

PiperOrigin-RevId: 157781368
This commit is contained in:
A. Unique TensorFlower 2017-06-01 16:45:07 -07:00 committed by TensorFlower Gardener
parent f7de292df3
commit fb4bc806a8

View File

@ -228,9 +228,9 @@ class BroadcastSimpleTest(test.TestCase):
class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
"""Tests concurrent sessions executing on the same GPU."""
def _run_session(self, results):
def _run_session(self, session, results):
n_iterations = 500
with self.test_session(use_gpu=True) as s:
with session as s:
data = variables.Variable(1.0)
with ops.device('/gpu:0'):
random_seed.set_random_seed(1)
@ -245,29 +245,29 @@ class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
for _ in xrange(n_iterations):
value = s.run(x4)
results.append(value)
if value != results[0]:
results.add(value.flat[0])
if len(results) != 1:
break
def testConcurrentSessions(self):
if not test.is_gpu_available():
return
n_threads = 4
results = [[]] * n_threads
threads = [
threading.Thread(target=self._run_session, args=(results[i],))
for i in xrange(n_threads)
]
threads = []
results = []
for _ in xrange(n_threads):
session = self.test_session(graph=ops.Graph(), use_gpu=True)
results.append(set())
args = (session, results[-1])
threads.append(threading.Thread(target=self._run_session, args=args))
for thread in threads:
thread.start()
for thread in threads:
thread.join()
flat_results = [x for x in itertools.chain(*results)]
self.assertNotEqual(0, len(flat_results))
for result in flat_results:
self.assertEqual(result, flat_results[0])
flat_results = set([x for x in itertools.chain(*results)])
self.assertEqual(1,
len(flat_results),
'Expected single value, got %r' % flat_results)
if __name__ == '__main__':