mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Clean up a couple of items in the C2 test scaffolding (WIP) (#7847)
Summary: - Py3 compatibility - utility functions refactoring Pull Request resolved: https://github.com/pytorch/pytorch/pull/7847 Reviewed By: pietern Differential Revision: D9355096 Pulled By: huitseeker fbshipit-source-id: 8e78faa937488c5299714f78075d7cadb1b2490c
This commit is contained in:
parent
10fdcf748a
commit
edd2e38023
|
|
@ -50,6 +50,7 @@ import hypothesis.strategies as st
|
|||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import six
|
||||
|
||||
|
||||
def is_sandcastle():
|
||||
|
|
@ -355,6 +356,7 @@ class HypothesisTestCase(test_util.TestCase):
|
|||
A unittest.TestCase subclass with some helper functions for
|
||||
utilizing the `hypothesis` (hypothesis.readthedocs.io) library.
|
||||
"""
|
||||
|
||||
def assertDeviceChecks(
|
||||
self,
|
||||
device_options,
|
||||
|
|
@ -689,5 +691,5 @@ class HypothesisTestCase(test_util.TestCase):
|
|||
if regexp is None:
|
||||
self.assertRaises(exception, workspace.RunOperatorOnce, op)
|
||||
else:
|
||||
self.assertRaisesRegexp(
|
||||
exception, regexp, workspace.RunOperatorOnce, op)
|
||||
six.assertRaisesRegex(
|
||||
self, exception, regexp, workspace.RunOperatorOnce, op)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from caffe2.python.modeling.parameter_sharing import (
|
|||
ParameterSharing,
|
||||
)
|
||||
from caffe2.python.layer_test_util import LayersTestCase
|
||||
import six
|
||||
|
||||
|
||||
class ParameterSharingTest(LayersTestCase):
|
||||
|
|
@ -114,7 +115,7 @@ class ParameterSharingTest(LayersTestCase):
|
|||
self.assertEquals(self.model.layers[-1].w,
|
||||
'global_scope/fc/w')
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'Got inconsistent shapes .*'):
|
||||
with six.assertRaisesRegex(self, ValueError, 'Got inconsistent shapes .*'):
|
||||
self.model.FC(
|
||||
self.model.input_feature_schema.float_features,
|
||||
output_dims + 1
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import json
|
||||
import os
|
||||
import six
|
||||
import unittest
|
||||
|
||||
from caffe2.python import core
|
||||
|
|
@ -24,12 +25,12 @@ import caffe2.python.onnx.backend as c2
|
|||
import numpy as np
|
||||
from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory
|
||||
|
||||
from caffe2.python.onnx.tests.test_utils import TestCase
|
||||
from caffe2.python.onnx.tests.test_utils import DownloadingTestCase
|
||||
|
||||
import caffe2.python._import_c_extension as C
|
||||
|
||||
|
||||
class TestCaffe2Basic(TestCase):
|
||||
class TestCaffe2Basic(DownloadingTestCase):
|
||||
def test_dummy_name(self):
|
||||
g = C.DummyName()
|
||||
n1 = g.new_dummy_name()
|
||||
|
|
@ -43,9 +44,9 @@ class TestCaffe2Basic(TestCase):
|
|||
b2.convert_node(node_def.SerializeToString())
|
||||
|
||||
bad_node_def = make_node("Add", inputs=["X", "Y"], outputs=["Z"], foo=42, bar=56)
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError,
|
||||
"Don't know how to map unexpected argument (foo|bar)"):
|
||||
with six.assertRaisesRegex(self,
|
||||
RuntimeError,
|
||||
"Don't know how to map unexpected argument (foo|bar)"):
|
||||
b2.convert_node(bad_node_def.SerializeToString())
|
||||
|
||||
def test_relu_graph(self):
|
||||
|
|
@ -550,7 +551,7 @@ class TestCaffe2Basic(TestCase):
|
|||
self.assertSameOutputs(c2_outputs, onnx_outputs)
|
||||
|
||||
|
||||
class TestCaffe2End2End(TestCase):
|
||||
class TestCaffe2End2End(DownloadingTestCase):
|
||||
def _model_dir(self, model):
|
||||
caffe2_home = os.path.expanduser(os.getenv('CAFFE2_HOME', '~/.caffe2'))
|
||||
models_dir = os.getenv('ONNX_MODELS', os.path.join(caffe2_home, 'models'))
|
||||
|
|
@ -582,36 +583,15 @@ class TestCaffe2End2End(TestCase):
|
|||
_, c2_outputs = c2_native_run_net(c2_init_net, c2_predict_net, inputs)
|
||||
del _
|
||||
|
||||
model = c2_onnx.caffe2_net_to_onnx_model(
|
||||
predict_net=c2_predict_net,
|
||||
init_net=c2_init_net,
|
||||
value_info=json.load(open(os.path.join(model_dir, 'value_info.json'))))
|
||||
with open(os.path.join(model_dir, 'value_info.json'), 'r') as value_info_conf:
|
||||
model = c2_onnx.caffe2_net_to_onnx_model(
|
||||
predict_net=c2_predict_net,
|
||||
init_net=c2_init_net,
|
||||
value_info=json.load(value_info_conf))
|
||||
c2_ir = c2.prepare(model)
|
||||
onnx_outputs = c2_ir.run(inputs)
|
||||
self.assertSameOutputs(c2_outputs, onnx_outputs, decimal=decimal)
|
||||
|
||||
def _download(self, model):
|
||||
model_dir = self._model_dir(model)
|
||||
assert not os.path.exists(model_dir)
|
||||
os.makedirs(model_dir)
|
||||
for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']:
|
||||
url = getURLFromName(model, f)
|
||||
dest = os.path.join(model_dir, f)
|
||||
try:
|
||||
try:
|
||||
downloadFromURLToFile(url, dest,
|
||||
show_progress=False)
|
||||
except TypeError:
|
||||
# show_progress not supported prior to
|
||||
# Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
|
||||
# (Sep 17, 2017)
|
||||
downloadFromURLToFile(url, dest)
|
||||
except Exception as e:
|
||||
print("Abort: {reason}".format(reason=e))
|
||||
print("Cleaning up...")
|
||||
deleteDirectory(model_dir)
|
||||
exit(1)
|
||||
|
||||
def test_alexnet(self):
|
||||
self._test_net('bvlc_alexnet', decimal=4)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import six
|
||||
import tempfile
|
||||
import textwrap
|
||||
import traceback
|
||||
|
|
@ -55,7 +56,7 @@ class TestConversion(TestCase):
|
|||
caffe2_init_net.write(init_model.net.Proto().SerializeToString())
|
||||
caffe2_init_net.flush()
|
||||
|
||||
result = self._run_command(
|
||||
self._run_command(
|
||||
caffe2_to_onnx, [
|
||||
caffe2_net.name,
|
||||
'--caffe2-init-net', caffe2_init_net.name,
|
||||
|
|
@ -81,16 +82,16 @@ class TestConversion(TestCase):
|
|||
caffe2_net.flush()
|
||||
|
||||
args = [caffe2_net.name, '--output', output.name]
|
||||
self.assertRaisesRegexp(Exception,
|
||||
'value info',
|
||||
self._run_command, caffe2_to_onnx, args)
|
||||
six.assertRaisesRegex(self, Exception,
|
||||
'value info',
|
||||
self._run_command, caffe2_to_onnx, args)
|
||||
|
||||
args.extend([
|
||||
'--value-info',
|
||||
json.dumps({
|
||||
'X': (TensorProto.FLOAT, (2, 2)),
|
||||
})])
|
||||
result = self._run_command(caffe2_to_onnx, args)
|
||||
self._run_command(caffe2_to_onnx, args)
|
||||
|
||||
onnx_model = ModelProto()
|
||||
onnx_model.ParseFromString(output.read())
|
||||
|
|
@ -119,7 +120,7 @@ class TestConversion(TestCase):
|
|||
onnx_model.write(model_def.SerializeToString())
|
||||
onnx_model.flush()
|
||||
|
||||
result = self._run_command(
|
||||
self._run_command(
|
||||
onnx_to_caffe2, [
|
||||
onnx_model.name,
|
||||
'--output', output.name,
|
||||
|
|
@ -138,12 +139,9 @@ class TestConversion(TestCase):
|
|||
for init_op in caffe2_init_net.op], [])),
|
||||
{'W'})
|
||||
|
||||
|
||||
def test_onnx_to_caffe2_zipfile(self):
|
||||
buf = tempfile.NamedTemporaryFile()
|
||||
onnx_model = zipfile.ZipFile(buf, 'w')
|
||||
output = tempfile.NamedTemporaryFile()
|
||||
init_net_output = tempfile.NamedTemporaryFile()
|
||||
|
||||
node_def = helper.make_node(
|
||||
"MatMul", ["X", "W"], ["Y"])
|
||||
|
|
|
|||
|
|
@ -6,12 +6,15 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory
|
||||
|
||||
|
||||
class TestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
np.random.seed(seed=0)
|
||||
|
||||
|
|
@ -21,9 +24,34 @@ class TestCase(unittest.TestCase):
|
|||
self.assertEqual(o1.dtype, o2.dtype)
|
||||
np.testing.assert_almost_equal(o1, o2, decimal=decimal)
|
||||
|
||||
def add_test_case(name, test_func):
|
||||
def add_test_case(self, name, test_func):
|
||||
if not name.startswith('test_'):
|
||||
raise ValueError('Test name must start with test_: {}'.format(name))
|
||||
if hasattr(self, name):
|
||||
raise ValueError('Duplicated test name: {}'.format(name))
|
||||
setattr(self, name, test_func)
|
||||
|
||||
|
||||
class DownloadingTestCase(TestCase):
|
||||
|
||||
def _download(self, model):
|
||||
model_dir = self._model_dir(model)
|
||||
assert not os.path.exists(model_dir)
|
||||
os.makedirs(model_dir)
|
||||
for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']:
|
||||
url = getURLFromName(model, f)
|
||||
dest = os.path.join(model_dir, f)
|
||||
try:
|
||||
try:
|
||||
downloadFromURLToFile(url, dest,
|
||||
show_progress=False)
|
||||
except TypeError:
|
||||
# show_progress not supported prior to
|
||||
# Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
|
||||
# (Sep 17, 2017)
|
||||
downloadFromURLToFile(url, dest)
|
||||
except Exception as e:
|
||||
print("Abort: {reason}".format(reason=e))
|
||||
print("Cleaning up...")
|
||||
deleteDirectory(model_dir)
|
||||
raise AssertionError("Test model downloading failed")
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from caffe2.python import core, workspace
|
||||
from caffe2.python.test_util import TestCase
|
||||
|
|
@ -61,9 +62,9 @@ class TestLengthsToShapeOps(TestCase):
|
|||
_test_reshape(old_shape=(4, 3, 2), new_shape=(-1, 0),
|
||||
expected_shape=(8, 3), arg_shape=False)
|
||||
|
||||
self.assertRaisesRegexp(RuntimeError, "size is zero",
|
||||
_test_reshape, old_shape=(2, 0), new_shape=(-1, 0),
|
||||
expected_shape=(2, 0), arg_shape=False)
|
||||
with six.assertRaisesRegex(self, RuntimeError, "size is zero"):
|
||||
_test_reshape(old_shape=(2, 0), new_shape=(-1, 0),
|
||||
expected_shape=(2, 0), arg_shape=False)
|
||||
|
||||
def test_backprop(self):
|
||||
old_shape = (4, 2, 1)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import caffe2.python.serialized_test.serialized_test_util as serial
|
|||
import hypothesis.strategies as st
|
||||
import numpy as np
|
||||
import random
|
||||
import six
|
||||
import unittest
|
||||
|
||||
|
||||
|
|
@ -403,13 +404,13 @@ class TestUtilityOps(serial.SerializedTestCase):
|
|||
)
|
||||
self.assertDeviceChecks(dc, op, inputs, [0])
|
||||
|
||||
with self.assertRaisesRegexp(RuntimeError, 'Step size cannot be 0'):
|
||||
inputs = (np.array(0), np.array(10), np.array(0))
|
||||
op = core.CreateOperator(
|
||||
"Range",
|
||||
names[len(inputs) - 1],
|
||||
["Y"]
|
||||
)
|
||||
inputs = (np.array(0), np.array(10), np.array(0))
|
||||
op = core.CreateOperator(
|
||||
"Range",
|
||||
names[len(inputs) - 1],
|
||||
["Y"]
|
||||
)
|
||||
with six.assertRaisesRegex(self, RuntimeError, 'Step size cannot be 0'):
|
||||
self.assertReferenceChecks(
|
||||
device_option=gc,
|
||||
op=op,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import caffe2.python.hypothesis_test_util as hu
|
|||
from hypothesis import given
|
||||
import hypothesis.strategies as st
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
|
||||
def SubFunctionThatThrowsRuntimeError():
|
||||
|
|
@ -48,9 +49,7 @@ class PythonOpTest(hu.HypothesisTestCase):
|
|||
|
||||
def test_exception(self):
|
||||
op = CreatePythonOperator(MainOpFunctionThatThrowsRuntimeError, [], [])
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError, "This is an intentional exception."
|
||||
):
|
||||
with six.assertRaisesRegex(self, RuntimeError, "This is an intentional exception."):
|
||||
workspace.RunOperatorOnce(op)
|
||||
|
||||
@given(x=hu.tensor())
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from caffe2.python.models.download import downloadFromURLToFile, getURLFromName,
|
|||
import caffe2.python.onnx.backend as c2
|
||||
from caffe2.python.onnx.workspace import Workspace
|
||||
from caffe2.python.trt.transform import convert_onnx_model_to_trt_op, transform_caffe2_net
|
||||
from caffe2.python.onnx.tests.test_utils import TestCase
|
||||
from caffe2.python.onnx.tests.test_utils import TestCase, DownloadingTestCase
|
||||
import numpy as np
|
||||
import os.path
|
||||
import json
|
||||
|
|
@ -170,34 +170,13 @@ class TensorRTOpTest(TestCase):
|
|||
def test_vgg19(self):
|
||||
self._test_onnx_importer('vgg19', -1)
|
||||
|
||||
class TensorRTTransformTest(TestCase):
|
||||
|
||||
class TensorRTTransformTest(DownloadingTestCase):
|
||||
def _model_dir(self, model):
|
||||
caffe2_home = os.path.expanduser(os.getenv('CAFFE2_HOME', '~/.caffe2'))
|
||||
models_dir = os.getenv('CAFFE2_MODELS', os.path.join(caffe2_home, 'models'))
|
||||
return os.path.join(models_dir, model)
|
||||
|
||||
def _download(self, model):
|
||||
model_dir = self._model_dir(model)
|
||||
assert not os.path.exists(model_dir)
|
||||
os.makedirs(model_dir)
|
||||
for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']:
|
||||
url = getURLFromName(model, f)
|
||||
dest = os.path.join(model_dir, f)
|
||||
try:
|
||||
try:
|
||||
downloadFromURLToFile(url, dest,
|
||||
show_progress=False)
|
||||
except TypeError:
|
||||
# show_progress not supported prior to
|
||||
# Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
|
||||
# (Sep 17, 2017)
|
||||
downloadFromURLToFile(url, dest)
|
||||
except Exception as e:
|
||||
print("Abort: {reason}".format(reason=e))
|
||||
print("Cleaning up...")
|
||||
deleteDirectory(model_dir)
|
||||
exit(1)
|
||||
|
||||
def _get_c2_model(self, model_name):
|
||||
model_dir = self._model_dir(model_name)
|
||||
if not os.path.exists(model_dir):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user