mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Applies PLW0108 which removes useless lambda calls in Python, the rule is in preview so it is not ready to be enabled by default just yet. These are the autofixes from the rule. Pull Request resolved: https://github.com/pytorch/pytorch/pull/113602 Approved by: https://github.com/albanD
120 lines
3.4 KiB
Python
120 lines
3.4 KiB
Python
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
from caffe2.python import workspace, core
|
|
import caffe2.python.parallel_workers as parallel_workers
|
|
|
|
|
|
def create_queue():
|
|
queue = 'queue'
|
|
|
|
workspace.RunOperatorOnce(
|
|
core.CreateOperator(
|
|
"CreateBlobsQueue", [], [queue], num_blobs=1, capacity=1000
|
|
)
|
|
)
|
|
# Technically, blob creations aren't thread safe. Since the unittest below
|
|
# does RunOperatorOnce instead of CreateNet+RunNet, we have to precreate
|
|
# all blobs beforehand
|
|
for i in range(100):
|
|
workspace.C.Workspace.current.create_blob("blob_" + str(i))
|
|
workspace.C.Workspace.current.create_blob("status_blob_" + str(i))
|
|
workspace.C.Workspace.current.create_blob("dequeue_blob")
|
|
workspace.C.Workspace.current.create_blob("status_blob")
|
|
|
|
return queue
|
|
|
|
|
|
def create_worker(queue, get_blob_data):
|
|
def dummy_worker(worker_id):
|
|
blob = 'blob_' + str(worker_id)
|
|
|
|
workspace.FeedBlob(blob, get_blob_data(worker_id))
|
|
|
|
workspace.RunOperatorOnce(
|
|
core.CreateOperator(
|
|
'SafeEnqueueBlobs', [queue, blob], [blob, 'status_blob_' + str(worker_id)]
|
|
)
|
|
)
|
|
|
|
return dummy_worker
|
|
|
|
|
|
def dequeue_value(queue):
|
|
dequeue_blob = 'dequeue_blob'
|
|
workspace.RunOperatorOnce(
|
|
core.CreateOperator(
|
|
"SafeDequeueBlobs", [queue], [dequeue_blob, 'status_blob']
|
|
)
|
|
)
|
|
|
|
return workspace.FetchBlob(dequeue_blob)
|
|
|
|
|
|
class ParallelWorkersTest(unittest.TestCase):
|
|
def testParallelWorkers(self):
|
|
workspace.ResetWorkspace()
|
|
|
|
queue = create_queue()
|
|
dummy_worker = create_worker(queue, str)
|
|
worker_coordinator = parallel_workers.init_workers(dummy_worker)
|
|
worker_coordinator.start()
|
|
|
|
for _ in range(10):
|
|
value = dequeue_value(queue)
|
|
self.assertTrue(
|
|
value in [b'0', b'1'], 'Got unexpected value ' + str(value)
|
|
)
|
|
|
|
self.assertTrue(worker_coordinator.stop())
|
|
|
|
def testParallelWorkersInitFun(self):
|
|
workspace.ResetWorkspace()
|
|
|
|
queue = create_queue()
|
|
dummy_worker = create_worker(
|
|
queue, lambda worker_id: workspace.FetchBlob('data')
|
|
)
|
|
workspace.FeedBlob('data', 'not initialized')
|
|
|
|
def init_fun(worker_coordinator, global_coordinator):
|
|
workspace.FeedBlob('data', 'initialized')
|
|
|
|
worker_coordinator = parallel_workers.init_workers(
|
|
dummy_worker, init_fun=init_fun
|
|
)
|
|
worker_coordinator.start()
|
|
|
|
for _ in range(10):
|
|
value = dequeue_value(queue)
|
|
self.assertEqual(
|
|
value, b'initialized', 'Got unexpected value ' + str(value)
|
|
)
|
|
|
|
# A best effort attempt at a clean shutdown
|
|
worker_coordinator.stop()
|
|
|
|
def testParallelWorkersShutdownFun(self):
|
|
workspace.ResetWorkspace()
|
|
|
|
queue = create_queue()
|
|
dummy_worker = create_worker(queue, str)
|
|
workspace.FeedBlob('data', 'not shutdown')
|
|
|
|
def shutdown_fun():
|
|
workspace.FeedBlob('data', 'shutdown')
|
|
|
|
worker_coordinator = parallel_workers.init_workers(
|
|
dummy_worker, shutdown_fun=shutdown_fun
|
|
)
|
|
worker_coordinator.start()
|
|
|
|
self.assertTrue(worker_coordinator.stop())
|
|
|
|
data = workspace.FetchBlob('data')
|
|
self.assertEqual(data, b'shutdown', 'Got unexpected value ' + str(data))
|