pytorch/caffe2/python/parallel_workers_test.py
Kevin Wilfong e3a7c78f04 Add shutdown_fun to parallel_workers
Summary:
parallel_workers supports calling a custom function "init_fun" when WorkerCoordinators are started which is passed in as an argument to init_workers.

Adding an analogous argument "shutdown_fun" which gets passed in to init_workers, and gets called when a WorkerCoordinator is stopped.

This allows users of the parallel_workers to add custom cleanup logic before the workers are stopped.

Reviewed By: akyrola

Differential Revision: D6020788

fbshipit-source-id: 1e1d8536a304a35fc9553407727da36446c668a3
2017-10-10 12:02:24 -07:00

126 lines
3.7 KiB
Python

# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
from caffe2.python import workspace, core
import caffe2.python.parallel_workers as parallel_workers
def create_queue():
queue = 'queue'
workspace.RunOperatorOnce(
core.CreateOperator(
"CreateBlobsQueue", [], [queue], num_blobs=1, capacity=1000
)
)
return queue
def create_worker(queue, get_blob_data):
def dummy_worker(worker_id):
blob = 'blob_' + str(worker_id)
workspace.FeedBlob(blob, get_blob_data(worker_id))
workspace.RunOperatorOnce(
core.CreateOperator(
'SafeEnqueueBlobs', [queue, blob], [blob, 'status_blob']
)
)
return dummy_worker
def dequeue_value(queue):
dequeue_blob = 'dequeue_blob'
workspace.RunOperatorOnce(
core.CreateOperator(
"SafeDequeueBlobs", [queue], [dequeue_blob, 'status_blob']
)
)
return workspace.FetchBlob(dequeue_blob)
class ParallelWorkersTest(unittest.TestCase):
def testParallelWorkers(self):
workspace.ResetWorkspace()
queue = create_queue()
dummy_worker = create_worker(queue, lambda worker_id: str(worker_id))
worker_coordinator = parallel_workers.init_workers(dummy_worker)
worker_coordinator.start()
for _ in range(10):
value = dequeue_value(queue)
self.assertTrue(
value in [b'0', b'1'], 'Got unexpected value ' + str(value)
)
self.assertTrue(worker_coordinator.stop())
def testParallelWorkersInitFun(self):
workspace.ResetWorkspace()
queue = create_queue()
dummy_worker = create_worker(
queue, lambda worker_id: workspace.FetchBlob('data')
)
workspace.FeedBlob('data', 'not initialized')
def init_fun(worker_coordinator, global_coordinator):
workspace.FeedBlob('data', 'initialized')
worker_coordinator = parallel_workers.init_workers(
dummy_worker, init_fun=init_fun
)
worker_coordinator.start()
for _ in range(10):
value = dequeue_value(queue)
self.assertEqual(
value, b'initialized', 'Got unexpected value ' + str(value)
)
self.assertTrue(worker_coordinator.stop())
def testParallelWorkersShutdownFun(self):
workspace.ResetWorkspace()
queue = create_queue()
dummy_worker = create_worker(queue, lambda worker_id: str(worker_id))
workspace.FeedBlob('data', 'not shutdown')
def shutdown_fun():
workspace.FeedBlob('data', 'shutdown')
worker_coordinator = parallel_workers.init_workers(
dummy_worker, shutdown_fun=shutdown_fun
)
worker_coordinator.start()
self.assertTrue(worker_coordinator.stop())
data = workspace.FetchBlob('data')
self.assertEqual(data, b'shutdown', 'Got unexpected value ' + str(data))