mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Apply parts of pyupgrade to torch (starting with the safest changes). This PR only does two things: removes the need to inherit from object and removes unused future imports. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308 Approved by: https://github.com/ezyang, https://github.com/albanD
77 lines
2.2 KiB
Python
77 lines
2.2 KiB
Python
## @package store_ops_test_util
|
|
# Module caffe2.distributed.store_ops_test_util
|
|
|
|
|
|
|
|
|
|
|
|
from multiprocessing import Process, Queue
|
|
|
|
import numpy as np
|
|
|
|
from caffe2.python import core, workspace
|
|
|
|
|
|
class StoreOpsTests:
|
|
@classmethod
|
|
def _test_set_get(cls, queue, create_store_handler_fn, index, num_procs):
|
|
store_handler = create_store_handler_fn()
|
|
blob = "blob"
|
|
value = np.full(1, 1, np.float32)
|
|
|
|
# Use last process to set blob to make sure other processes
|
|
# are waiting for the blob before it is set.
|
|
if index == (num_procs - 1):
|
|
workspace.FeedBlob(blob, value)
|
|
workspace.RunOperatorOnce(
|
|
core.CreateOperator(
|
|
"StoreSet",
|
|
[store_handler, blob],
|
|
[],
|
|
blob_name=blob))
|
|
|
|
output_blob = "output_blob"
|
|
workspace.RunOperatorOnce(
|
|
core.CreateOperator(
|
|
"StoreGet",
|
|
[store_handler],
|
|
[output_blob],
|
|
blob_name=blob))
|
|
|
|
try:
|
|
np.testing.assert_array_equal(workspace.FetchBlob(output_blob), 1)
|
|
except AssertionError as err:
|
|
queue.put(err)
|
|
|
|
workspace.ResetWorkspace()
|
|
|
|
@classmethod
|
|
def test_set_get(cls, create_store_handler_fn):
|
|
# Queue for assertion errors on subprocesses
|
|
queue = Queue()
|
|
|
|
# Start N processes in the background
|
|
num_procs = 4
|
|
procs = []
|
|
for index in range(num_procs):
|
|
proc = Process(
|
|
target=cls._test_set_get,
|
|
args=(queue, create_store_handler_fn, index, num_procs, ))
|
|
proc.start()
|
|
procs.append(proc)
|
|
|
|
# Test complete, join background processes
|
|
for proc in procs:
|
|
proc.join()
|
|
|
|
# Raise first error we find, if any
|
|
if not queue.empty():
|
|
raise queue.get()
|
|
|
|
@classmethod
|
|
def test_get_timeout(cls, create_store_handler_fn):
|
|
store_handler = create_store_handler_fn()
|
|
net = core.Net('get_missing_blob')
|
|
net.StoreGet([store_handler], 1, blob_name='blob')
|
|
workspace.RunNetOnce(net)
|