#!/usr/bin/env python3

import shutil
import tempfile
import unittest
from multiprocessing import Process

import hypothesis.strategies as st
import numpy as np
from hypothesis import given, settings

import caffe2.python.hypothesis_test_util as hu

# Engine passed in the rendezvous used by the multi-process all-compare test.
op_engine = 'GLOO'


class TemporaryDirectory:
    """Context manager that creates a temporary directory and deletes it on exit.

    ``__enter__`` returns the directory path (a string from
    ``tempfile.mkdtemp``); ``__exit__`` removes the directory tree with
    ``shutil.rmtree`` whether or not the ``with`` body raised.
    """

    def __enter__(self):
        self.tmpdir = tempfile.mkdtemp()
        return self.tmpdir

    def __exit__(self, exc_type, exc_value, traceback):
        shutil.rmtree(self.tmpdir)


def allcompare_process(filestore_dir, process_id, data, num_procs):
    """Run one shard of the cross-process data comparison.

    Used as a ``multiprocessing.Process`` target. Registers the file store
    handler ops library through ``lazy_dyndep``, creates a ``store_handler``
    blob backed by ``filestore_dir``, feeds ``data`` into the ``test_data``
    blob and calls ``data_parallel_model._RunComparison`` on CPU device 0
    with a rendezvous describing this shard.

    Args:
        filestore_dir: Directory handed to ``FileStoreHandlerCreate``; all
            shards of one run receive the same directory.
        process_id: Shard id of this process (``0 <= process_id < num_procs``).
        data: Array fed into the ``test_data`` blob.
        num_procs: Total number of shards in the rendezvous.
    """
    # Imports are function-local, so they execute inside the child process.
    from caffe2.proto import caffe2_pb2
    from caffe2.python import core, data_parallel_model, lazy_dyndep, workspace
    from caffe2.python.model_helper import ModelHelper

    lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")

    workspace.RunOperatorOnce(
        core.CreateOperator(
            "FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
        )
    )
    rendezvous = dict(
        kv_handler="store_handler",
        shard_id=process_id,
        num_shards=num_procs,
        engine=op_engine,
        exit_nets=None,
    )

    model = ModelHelper()
    model._rendezvous = rendezvous

    workspace.FeedBlob("test_data", data)

    data_parallel_model._RunComparison(
        model, "test_data", core.DeviceOption(caffe2_pb2.CPU, 0)
    )


class TestLazyDynDepAllCompare(hu.HypothesisTestCase):
    """Runs ``allcompare_process`` across several processes with identical data."""

    @given(
        d=st.integers(1, 5),
        n=st.integers(2, 11),
        num_procs=st.integers(1, 8),
    )
    @settings(deadline=None)
    def test_allcompare(self, d, n, num_procs):
        """Every shard compares the same array; all shards must exit cleanly."""
        # ``d`` dimensions, each of size in [1, n).
        dims = [np.random.randint(1, high=n) for _ in range(d)]
        test_data = np.random.random_sample(size=tuple(dims)).astype(np.float32)

        with TemporaryDirectory() as tempdir:
            processes = [
                Process(
                    target=allcompare_process,
                    args=(tempdir, idx, test_data, num_procs),
                )
                for idx in range(num_procs)
            ]
            for process in processes:
                process.start()
            for process in processes:
                process.join()

        # An exception in a child only prints a traceback there; without
        # checking exit codes a failed comparison would go unnoticed.
        failed = {
            idx: process.exitcode
            for idx, process in enumerate(processes)
            if process.exitcode != 0
        }
        self.assertEqual(
            failed, {}, "allcompare shards exited with non-zero exit codes"
        )


class TestLazyDynDepError(unittest.TestCase):
    """Error-handler behaviour of ``lazy_dyndep`` when a library fails to load."""

    def test_errorhandler(self):
        """A registered handler's exception propagates out of the refresh."""
        from caffe2.python import core, lazy_dyndep

        # An empty temporary file is registered as an ops library; loading
        # it is expected to fail and invoke the handler.
        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)

            def handler(e):
                raise ValueError("test")

            lazy_dyndep.SetErrorHandler(handler)
            # assertRaises' ``msg`` is only the failure text; match the
            # handler's message explicitly instead.
            with self.assertRaisesRegex(ValueError, "test"):
                core.RefreshRegisteredOperators()

    def test_importaftererror(self):
        """A valid library can still be loaded after a failed lazy load."""
        from caffe2.python import core, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)

            def handler(e):
                raise ValueError("test")

            lazy_dyndep.SetErrorHandler(handler)
            with self.assertRaisesRegex(ValueError, "test"):
                core.RefreshRegisteredOperators()

            def handlernoop(e):
                # Bare ``raise``: re-raises the exception being handled when
                # this is called, so a failure in the refresh below is not
                # swallowed.
                raise

            lazy_dyndep.SetErrorHandler(handlernoop)
            lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
            core.RefreshRegisteredOperators()

    def test_workspacecreatenet(self):
        """``workspace.CreateNet`` surfaces the lazy-load handler's error."""
        from caffe2.python import lazy_dyndep, workspace

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)

            def handler(e):
                raise ValueError("test")

            lazy_dyndep.SetErrorHandler(handler)
            with self.assertRaisesRegex(ValueError, "test"):
                workspace.CreateNet("fake")


if __name__ == "__main__":
    unittest.main()