pytorch/caffe2/python/lazy_dyndep_test.py
Colin L Reliability Rice 07fd5f8ff9 Create lazy_dyndeps to avoid caffe2 import costs. (#39488)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39488

Currently caffe2.InitOpLibrary does the dll import unilaterally. Instead, if we make a lazy version and use it, then many pieces of code which do not need the caffe2 operators get a lot faster.

On a real test, the import time went from 140s to 68s.

This also cleans up the algorithm slightly (although it makes a very minimal
difference), by parsing the list of operators once, rather than every time a
new operator is added, since we defer the refresh call until after we've
imported all the operators.

The key way we maintain safety is that as soon as someone does an operation
which requires an operator (or could), we force importing of all available
operators.
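
As a rough illustration (not part of this diff; the library name below is just the one used in the new test, and the path is a placeholder), the intended flow looks like:

    from caffe2.python import core, lazy_dyndep, workspace

    # Cheap: this only records the library path; no dll import happens here.
    lazy_dyndep.RegisterOpsLibrary(
        "@/caffe2/caffe2/distributed:file_store_handler_ops"
    )

    # The first call that needs (or could need) an operator forces the
    # deferred import of every registered library before running.
    workspace.RunOperatorOnce(
        core.CreateOperator(
            "FileStoreHandlerCreate", [], ["store_handler"],
            path="/tmp/store"  # placeholder path for illustration
        )
    )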

Future work could include trying to identify which code is needed for which
operator and only import the needed ones. There may also be wins available by
playing with dlmopen (which opens within a namespace), or seeing if the dl
flags have an impact (I tried this and didn't see an impact, but dlmopen may
make it better).

Test Plan:
I added a new test, lazy_dyndep_test.py (copied from all_compare_test.py).
I'm a little concerned that I don't see any explicit tests for dyndep, but this
should provide decent coverage.

Differential Revision: D21870844

fbshipit-source-id: 3f65fedb65bb48663670349cee5e1d3e22d560ed
2020-07-09 11:34:57 -07:00


#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from hypothesis import given
import hypothesis.strategies as st
from multiprocessing import Process

import numpy as np
import tempfile
import shutil

import caffe2.python.hypothesis_test_util as hu

import unittest

op_engine = 'GLOO'


class TemporaryDirectory:
    def __enter__(self):
        self.tmpdir = tempfile.mkdtemp()
        return self.tmpdir

    def __exit__(self, type, value, traceback):
        shutil.rmtree(self.tmpdir)


def allcompare_process(filestore_dir, process_id, data, num_procs):
    from caffe2.python import core, data_parallel_model, workspace, lazy_dyndep
    from caffe2.python.model_helper import ModelHelper
    from caffe2.proto import caffe2_pb2

    # Registering the library is cheap; the actual dll import is deferred
    # until an operator is actually run below.
    lazy_dyndep.RegisterOpsLibrary(
        "@/caffe2/caffe2/distributed:file_store_handler_ops"
    )
    workspace.RunOperatorOnce(
        core.CreateOperator(
            "FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
        )
    )
    rendezvous = dict(
        kv_handler="store_handler",
        shard_id=process_id,
        num_shards=num_procs,
        engine=op_engine,
        exit_nets=None
    )

    model = ModelHelper()
    model._rendezvous = rendezvous

    workspace.FeedBlob("test_data", data)

    data_parallel_model._RunComparison(
        model, "test_data", core.DeviceOption(caffe2_pb2.CPU, 0)
    )


class TestLazyDynDepAllCompare(hu.HypothesisTestCase):
    @given(
        d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8)
    )
    def test_allcompare(self, d, n, num_procs):
        dims = []
        for _ in range(d):
            dims.append(np.random.randint(1, high=n))
        test_data = np.random.ranf(size=tuple(dims)).astype(np.float32)

        with TemporaryDirectory() as tempdir:
            processes = []
            for idx in range(num_procs):
                process = Process(
                    target=allcompare_process,
                    args=(tempdir, idx, test_data, num_procs)
                )
                processes.append(process)
                process.start()

            while len(processes) > 0:
                process = processes.pop()
                process.join()


class TestLazyDynDepError(unittest.TestCase):
    def test_errorhandler(self):
        from caffe2.python import core, lazy_dyndep
        import tempfile

        with tempfile.NamedTemporaryFile() as f:
            # Register a file that is not a valid ops library; the failure
            # only surfaces once the deferred import is triggered.
            lazy_dyndep.RegisterOpsLibrary(f.name)

            def handler(e):
                raise ValueError("test")

            lazy_dyndep.SetErrorHandler(handler)
            with self.assertRaises(ValueError, msg="test"):
                core.RefreshRegisteredOperators()


if __name__ == "__main__":
    unittest.main()