pytorch/caffe2/distributed/store_ops_test_util.py
Aaron Gokaslan 8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00

77 lines
2.2 KiB
Python

## @package store_ops_test_util
# Module caffe2.distributed.store_ops_test_util
from multiprocessing import Process, Queue
import numpy as np
from caffe2.python import core, workspace
class StoreOpsTests:
@classmethod
def _test_set_get(cls, queue, create_store_handler_fn, index, num_procs):
store_handler = create_store_handler_fn()
blob = "blob"
value = np.full(1, 1, np.float32)
# Use last process to set blob to make sure other processes
# are waiting for the blob before it is set.
if index == (num_procs - 1):
workspace.FeedBlob(blob, value)
workspace.RunOperatorOnce(
core.CreateOperator(
"StoreSet",
[store_handler, blob],
[],
blob_name=blob))
output_blob = "output_blob"
workspace.RunOperatorOnce(
core.CreateOperator(
"StoreGet",
[store_handler],
[output_blob],
blob_name=blob))
try:
np.testing.assert_array_equal(workspace.FetchBlob(output_blob), 1)
except AssertionError as err:
queue.put(err)
workspace.ResetWorkspace()
@classmethod
def test_set_get(cls, create_store_handler_fn):
# Queue for assertion errors on subprocesses
queue = Queue()
# Start N processes in the background
num_procs = 4
procs = []
for index in range(num_procs):
proc = Process(
target=cls._test_set_get,
args=(queue, create_store_handler_fn, index, num_procs, ))
proc.start()
procs.append(proc)
# Test complete, join background processes
for proc in procs:
proc.join()
# Raise first error we find, if any
if not queue.empty():
raise queue.get()
@classmethod
def test_get_timeout(cls, create_store_handler_fn):
store_handler = create_store_handler_fn()
net = core.Net('get_missing_blob')
net.StoreGet([store_handler], 1, blob_name='blob')
workspace.RunNetOnce(net)