mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Fix flakiness in GpuMultiSessionMemoryTest.
PiperOrigin-RevId: 157781368
This commit is contained in:
parent
f7de292df3
commit
fb4bc806a8
|
|
@ -228,9 +228,9 @@ class BroadcastSimpleTest(test.TestCase):
|
|||
class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
|
||||
"""Tests concurrent sessions executing on the same GPU."""
|
||||
|
||||
def _run_session(self, results):
|
||||
def _run_session(self, session, results):
|
||||
n_iterations = 500
|
||||
with self.test_session(use_gpu=True) as s:
|
||||
with session as s:
|
||||
data = variables.Variable(1.0)
|
||||
with ops.device('/gpu:0'):
|
||||
random_seed.set_random_seed(1)
|
||||
|
|
@ -245,29 +245,29 @@ class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
|
|||
|
||||
for _ in xrange(n_iterations):
|
||||
value = s.run(x4)
|
||||
results.append(value)
|
||||
if value != results[0]:
|
||||
results.add(value.flat[0])
|
||||
if len(results) != 1:
|
||||
break
|
||||
|
||||
def testConcurrentSessions(self):
|
||||
if not test.is_gpu_available():
|
||||
return
|
||||
|
||||
n_threads = 4
|
||||
results = [[]] * n_threads
|
||||
threads = [
|
||||
threading.Thread(target=self._run_session, args=(results[i],))
|
||||
for i in xrange(n_threads)
|
||||
]
|
||||
threads = []
|
||||
results = []
|
||||
for _ in xrange(n_threads):
|
||||
session = self.test_session(graph=ops.Graph(), use_gpu=True)
|
||||
results.append(set())
|
||||
args = (session, results[-1])
|
||||
threads.append(threading.Thread(target=self._run_session, args=args))
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
flat_results = [x for x in itertools.chain(*results)]
|
||||
self.assertNotEqual(0, len(flat_results))
|
||||
for result in flat_results:
|
||||
self.assertEqual(result, flat_results[0])
|
||||
flat_results = set([x for x in itertools.chain(*results)])
|
||||
self.assertEqual(1,
|
||||
len(flat_results),
|
||||
'Expected single value, got %r' % flat_results)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user