pytorch/caffe2/python/lazy_dyndep_test.py
Colin L Reliability Rice dfa914a90c Modify lazy_dyndep loading to trigger inside workspace. (#41687)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41687

Specifically, this makes a new library (lazy), which can be used from both core
and workspace.

This allows workspace.CreateNet to trigger lazy loading of dyndep dependencies.

Test Plan: Added a unit test specifically for workspace.CreateNet

Reviewed By: dzhulgakov

Differential Revision: D22441877

fbshipit-source-id: 3a9d1af9962585d08ea2566c9c85bec7377d39f2
2020-07-22 15:36:43 -07:00

133 lines
3.9 KiB
Python

#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from hypothesis import given
import hypothesis.strategies as st
from multiprocessing import Process
import numpy as np
import tempfile
import shutil
import caffe2.python.hypothesis_test_util as hu
import unittest
# Value used for the ``engine`` entry of the rendezvous dict that
# allcompare_process attaches to its model.
op_engine = 'GLOO'
class TemporaryDirectory:
    """Context manager that yields a freshly created temporary directory.

    The directory is made with ``tempfile.mkdtemp`` on entry and removed,
    together with everything inside it, by ``shutil.rmtree`` on exit.
    """

    def __enter__(self):
        """Create the directory, remember its path and return that path."""
        path = tempfile.mkdtemp()
        self.tmpdir = path
        return path

    def __exit__(self, type, value, traceback):
        """Delete the directory tree created by ``__enter__``.

        Nothing is returned, so an exception raised inside the ``with`` body
        is not suppressed.
        """
        shutil.rmtree(self.tmpdir)
def allcompare_process(filestore_dir, process_id, data, num_procs):
    """Run one participant of the all-compare check.

    Intended as a ``multiprocessing.Process`` target: the caffe2 modules are
    imported here, inside the function, rather than at module level.

    Args:
        filestore_dir: passed as ``path`` to the ``FileStoreHandlerCreate``
            operator.
        process_id: stored as ``shard_id`` in the rendezvous dict.
        data: value fed into the workspace blob ``"test_data"``.
        num_procs: stored as ``num_shards`` in the rendezvous dict.
    """
    from caffe2.python import core, data_parallel_model, workspace, lazy_dyndep
    from caffe2.python.model_helper import ModelHelper
    from caffe2.proto import caffe2_pb2

    # Register the library lazily; it must be registered before the
    # FileStoreHandlerCreate operator below is created and run.
    lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")

    create_store_op = core.CreateOperator(
        "FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
    )
    workspace.RunOperatorOnce(create_store_op)

    model = ModelHelper()
    # _RunComparison is handed this model; the rendezvous description is
    # attached to it as a private attribute.
    model._rendezvous = {
        "kv_handler": "store_handler",
        "shard_id": process_id,
        "num_shards": num_procs,
        "engine": op_engine,
        "exit_nets": None,
    }

    workspace.FeedBlob("test_data", data)
    cpu_device = core.DeviceOption(caffe2_pb2.CPU, 0)
    data_parallel_model._RunComparison(model, "test_data", cpu_device)
class TestLazyDynDepAllCompare(hu.HypothesisTestCase):
    """Runs ``allcompare_process`` in several child processes at once."""

    @given(
        d=st.integers(min_value=1, max_value=5),
        n=st.integers(min_value=2, max_value=11),
        num_procs=st.integers(min_value=1, max_value=8),
    )
    def test_allcompare(self, d, n, num_procs):
        """Spawn ``num_procs`` workers sharing one random float32 array.

        The array has ``d`` dimensions, each extent drawn from
        ``np.random.randint(1, high=n)``. Every worker receives the same
        temporary directory, its own index, the array and ``num_procs``.

        NOTE(review): the workers' exit codes are never inspected, so an
        exception raised inside a child does not fail this test -- confirm
        whether that is intended.
        """
        shape = tuple(np.random.randint(1, high=n) for _ in range(d))
        test_data = np.random.ranf(size=shape).astype(np.float32)

        with TemporaryDirectory() as tempdir:
            workers = []
            for shard in range(num_procs):
                worker = Process(
                    target=allcompare_process,
                    args=(tempdir, shard, test_data, num_procs),
                )
                workers.append(worker)
                worker.start()

            # Wait for every worker, most recently started first, before the
            # temporary directory is removed.
            for worker in reversed(workers):
                worker.join()
class TestLazyDynDepError(unittest.TestCase):
    """Checks how errors raised while lazily loading an ops library surface.

    Each test registers the path of an empty ``NamedTemporaryFile`` as an ops
    library, installs an error handler through ``lazy_dyndep.SetErrorHandler``
    and then calls an entry point expected to trigger the lazy load.

    NOTE(review): the handlers installed here are never reset afterwards;
    confirm that this lazy_dyndep state does not leak into other tests.
    """

    @staticmethod
    def _raise_value_error(e):
        """Error handler that ignores ``e`` and raises ``ValueError("test")``."""
        raise ValueError("test")

    def test_errorhandler(self):
        """core.RefreshRegisteredOperators surfaces the handler's exception."""
        from caffe2.python import core, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)
            lazy_dyndep.SetErrorHandler(self._raise_value_error)
            # assertRaises' ``msg`` argument is only the failure message, so
            # the exception text is checked explicitly instead.
            with self.assertRaises(ValueError) as ctx:
                core.RefreshRegisteredOperators()
            self.assertEqual(str(ctx.exception), "test")

    def test_importaftererror(self):
        """A later registration and refresh still work after a failed load."""
        from caffe2.python import core, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)
            lazy_dyndep.SetErrorHandler(self._raise_value_error)
            with self.assertRaises(ValueError):
                core.RefreshRegisteredOperators()

            def handlernoop(e):
                # Bare ``raise`` re-raises whatever exception is being handled
                # at the time of the call. NOTE(review): presumably the
                # handler is invoked from inside an ``except`` block -- confirm.
                raise

            lazy_dyndep.SetErrorHandler(handlernoop)
            lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
            # Expected to complete without raising.
            core.RefreshRegisteredOperators()

    def test_workspacecreatenet(self):
        """workspace.CreateNet surfaces the handler's exception as well."""
        from caffe2.python import workspace, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)
            lazy_dyndep.SetErrorHandler(self._raise_value_error)
            with self.assertRaises(ValueError) as ctx:
                workspace.CreateNet("fake")
            self.assertEqual(str(ctx.exception), "test")
# Run this module's tests when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()