Fix flakiness in GpuMultiSessionMemoryTest.

PiperOrigin-RevId: 157781368
This commit is contained in:
A. Unique TensorFlower 2017-06-01 16:45:07 -07:00 committed by TensorFlower Gardener
parent f7de292df3
commit fb4bc806a8

View File

@ -228,9 +228,9 @@ class BroadcastSimpleTest(test.TestCase):
class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase): class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
"""Tests concurrent sessions executing on the same GPU.""" """Tests concurrent sessions executing on the same GPU."""
def _run_session(self, results): def _run_session(self, session, results):
n_iterations = 500 n_iterations = 500
with self.test_session(use_gpu=True) as s: with session as s:
data = variables.Variable(1.0) data = variables.Variable(1.0)
with ops.device('/gpu:0'): with ops.device('/gpu:0'):
random_seed.set_random_seed(1) random_seed.set_random_seed(1)
@ -245,29 +245,29 @@ class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
for _ in xrange(n_iterations): for _ in xrange(n_iterations):
value = s.run(x4) value = s.run(x4)
results.append(value) results.add(value.flat[0])
if value != results[0]: if len(results) != 1:
break break
def testConcurrentSessions(self): def testConcurrentSessions(self):
if not test.is_gpu_available():
return
n_threads = 4 n_threads = 4
results = [[]] * n_threads threads = []
threads = [ results = []
threading.Thread(target=self._run_session, args=(results[i],)) for _ in xrange(n_threads):
for i in xrange(n_threads) session = self.test_session(graph=ops.Graph(), use_gpu=True)
] results.append(set())
args = (session, results[-1])
threads.append(threading.Thread(target=self._run_session, args=args))
for thread in threads: for thread in threads:
thread.start() thread.start()
for thread in threads: for thread in threads:
thread.join() thread.join()
flat_results = [x for x in itertools.chain(*results)] flat_results = set([x for x in itertools.chain(*results)])
self.assertNotEqual(0, len(flat_results)) self.assertEqual(1,
for result in flat_results: len(flat_results),
self.assertEqual(result, flat_results[0]) 'Expected single value, got %r' % flat_results)
if __name__ == '__main__': if __name__ == '__main__':