pytorch/caffe2/python/parallel_workers_test.py
Kevin Wilfong d072701547 Caffe2: Refactor the core logic from data_workers.py into parallel_workers.py
Summary:
data_workers.py provides a really nice, easy way to run background threads for data input.  Unfortunately, it's restrictive: the output of the fetcher function has to be a numpy array.

I pulled that nice core thread-management logic out into parallel_workers, and updated the classes in data_workers to extend the new classes.  The main change was refactoring most of the queue handling logic out into QueueManager.

This way parallel_workers can be used to manage background threads without having to use the queue for output.

Reviewed By: akyrola

Differential Revision: D5538626

fbshipit-source-id: f382cc43f800ff90840582a378dc9b86ac05b613
2017-08-07 10:14:08 -07:00

91 lines
2.4 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
from caffe2.python import workspace, core
import caffe2.python.parallel_workers as parallel_workers
def create_queue():
    """Create a blobs queue in the workspace and return its blob name.

    Runs a single ``CreateBlobsQueue`` operator (``num_blobs=1``,
    ``capacity=1000``) whose only output is the blob named ``'queue'``.

    Returns:
        The name of the queue blob.
    """
    queue_name = 'queue'

    create_op = core.CreateOperator(
        "CreateBlobsQueue",
        [],
        [queue_name],
        num_blobs=1,
        capacity=1000,
    )
    workspace.RunOperatorOnce(create_op)

    return queue_name
def create_worker(queue, get_blob_data):
    """Build a worker function that enqueues one blob onto ``queue`` per call.

    Args:
        queue: Name of the queue blob the worker enqueues into.
        get_blob_data: Callable taking the worker id and returning the data
            to feed into the worker's blob.

    Returns:
        A function accepting ``worker_id``. Each call feeds the blob
        ``'blob_<worker_id>'`` with ``get_blob_data(worker_id)`` and then runs
        a ``SafeEnqueueBlobs`` operator on it.
    """
    def dummy_worker(worker_id):
        blob_name = 'blob_' + str(worker_id)

        payload = get_blob_data(worker_id)
        workspace.FeedBlob(blob_name, payload)

        enqueue_op = core.CreateOperator(
            'SafeEnqueueBlobs',
            [queue, blob_name],
            [blob_name, 'status_blob'],
        )
        workspace.RunOperatorOnce(enqueue_op)

    return dummy_worker
def dequeue_value(queue):
    """Dequeue one entry from ``queue`` and return the fetched blob value.

    Args:
        queue: Name of the queue blob to dequeue from.

    Returns:
        The contents of ``'dequeue_blob'`` after a ``SafeDequeueBlobs``
        operator has been run once against ``queue``.
    """
    output_blob = 'dequeue_blob'

    dequeue_op = core.CreateOperator(
        "SafeDequeueBlobs",
        [queue],
        [output_blob, 'status_blob'],
    )
    workspace.RunOperatorOnce(dequeue_op)

    return workspace.FetchBlob(output_blob)
class ParallelWorkersTest(unittest.TestCase):
    """Tests for ``parallel_workers.init_workers`` driven through a blobs queue.

    Each test resets the workspace, creates a queue, starts workers that
    enqueue data onto it, and checks the values that come back out.
    """

    def testParallelWorkers(self):
        """Workers enqueue their own id as a string; check dequeued values."""
        workspace.ResetWorkspace()

        queue = create_queue()
        dummy_worker = create_worker(queue, lambda worker_id: str(worker_id))
        worker_coordinator = parallel_workers.init_workers(dummy_worker)
        worker_coordinator.start()

        try:
            for _ in range(10):
                value = dequeue_value(queue)
                # NOTE(review): accepting only ids 0 and 1 presumably relies
                # on init_workers defaulting to two worker threads -- confirm
                # against parallel_workers.init_workers.
                self.assertIn(
                    value, [b'0', b'1'], 'Got unexpected value ' + str(value)
                )
        finally:
            # Stop the workers even when an assertion above fails; otherwise
            # they keep running and keep enqueuing into a blob named 'queue',
            # which the next test re-creates under the same name.
            stopped = worker_coordinator.stop()

        # Checked outside ``finally`` so a failed stop cannot mask an earlier
        # assertion failure from the loop.
        self.assertTrue(stopped)

    def testParallelWorkersInitFun(self):
        """``init_fun`` should run before workers produce any data."""
        workspace.ResetWorkspace()

        queue = create_queue()
        dummy_worker = create_worker(
            queue, lambda worker_id: workspace.FetchBlob('data')
        )
        # Seed the blob with a sentinel that init_fun is expected to replace.
        workspace.FeedBlob('data', 'not initialized')

        def init_fun(worker_coordinator, global_coordinator):
            workspace.FeedBlob('data', 'initialized')

        worker_coordinator = parallel_workers.init_workers(
            dummy_worker, init_fun=init_fun
        )
        worker_coordinator.start()

        try:
            for _ in range(10):
                value = dequeue_value(queue)
                # Every value must already reflect init_fun's overwrite.
                self.assertEqual(
                    value, b'initialized', 'Got unexpected value ' + str(value)
                )
        finally:
            # Always stop the workers so a failure here does not leave them
            # running into later tests.
            stopped = worker_coordinator.stop()

        self.assertTrue(stopped)