Core unit test fixes for Python 3

Summary: As title

Differential Revision: D5291327

fbshipit-source-id: 7dd9279c53ba55d3422c31973ffcec5705787fdf
Thomas Dudziak 2017-06-23 13:06:20 -07:00 committed by Facebook Github Bot
parent ff914bf201
commit 342de07231
18 changed files with 167 additions and 131 deletions

View File

@@ -8,13 +8,7 @@ from __future__ import unicode_literals
 import sys
 import copy
 import inspect
-try:
-    from past.builtins import basestring
-except ImportError:
-    print("You don't have the past package installed. ",
-          "This is necessary for python 2/3 compatibility. ",
-          "To do this, do 'pip install future'.")
-    sys.exit(1)
+from past.builtins import basestring
 from caffe2.python.model_helper import ModelHelper
 # flake8: noqa
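
Note: the guarded import goes away because the `future` package is now treated as a hard dependency. As a reference, a minimal sketch of what the compatibility aliases used throughout this commit (`basestring` from `past`, `string_types`/`binary_type`/`text_type` from `six`) resolve to on each Python version:

import sys

if sys.version_info[0] >= 3:
    string_types = (str,)         # six.string_types on Py3
    binary_type = bytes           # six.binary_type
    text_type = str               # six.text_type
else:
    string_types = (basestring,)  # noqa: F821 (Py2 only)
    binary_type = str
    text_type = unicode           # noqa: F821 (Py2 only)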

View File

@@ -6,14 +6,8 @@ from __future__ import print_function
 from __future__ import unicode_literals
 from collections import namedtuple, OrderedDict
-try:
-    from past.builtins import basestring
-except ImportError:
-    print("You don't have the past package installed. ",
-          "This is necessary for python 2/3 compatibility. ",
-          "To do this, do 'pip install future'.")
-    import sys
-    sys.exit(1)
+from past.builtins import basestring
+from six import binary_type, string_types, text_type
 from caffe2.proto import caffe2_pb2
 from collections import defaultdict
@@ -138,7 +132,12 @@ class BlobReference(object):
         Note that this does not prepend the namescope. If needed, use
         ScopedBlobReference() to prepend the existing namespace.
         """
-        self._name = name
+        if isinstance(name, string_types):
+            self._name = name
+        elif isinstance(name, binary_type):
+            self._name = name.decode('utf-8')
+        else:
+            self._name = str(name)
         self._from_net = net
         # meta allows helper functions to put whatever metainformation needed
         # there.
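
With this normalization, str and bytes names should collapse to the same reference. A quick illustrative check (assuming a built caffe2 whose core module is importable):

from caffe2.python import core

a = core.BlobReference("data")
b = core.BlobReference(b"data")   # decoded to u"data" by __init__
assert a == b
assert a == "data" and b == "data"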
@@ -148,8 +147,10 @@ class BlobReference(object):
         return hash(self._name)

     def __eq__(self, other):
-        if isinstance(other, basestring):
+        if isinstance(other, string_types):
             return self._name == other
+        elif isinstance(other, binary_type):
+            return self._name == other.decode('utf-8')
         elif isinstance(other, BlobReference):
             return self._name == other._name
         else:
@@ -165,12 +166,12 @@ class BlobReference(object):
         return 'BlobReference("{}")'.format(self._name)

     def __add__(self, other):
-        if not isinstance(other, basestring):
+        if not isinstance(other, string_types):
             raise RuntimeError('Cannot add BlobReference to a non-string.')
         return BlobReference(self._name + other, self._from_net)

     def __radd__(self, other):
-        if not isinstance(other, basestring):
+        if not isinstance(other, string_types):
             raise RuntimeError('Cannot add a non-string to BlobReference.')
         return BlobReference(other + self._name, self._from_net)
@@ -185,7 +186,7 @@ class BlobReference(object):
         network's __getattr__ function.
         """
         inputs = [] if inputs is None else inputs
-        if isinstance(inputs, BlobReference) or isinstance(inputs, str):
+        if isinstance(inputs, BlobReference) or isinstance(inputs, string_types):
             inputs = [inputs]
         # add self to the input list.
         inputs.insert(0, self)
@@ -228,7 +229,7 @@ class BlobReference(object):
 def ScopedName(name):
     """prefix the name with the current scope."""
-    if isinstance(name, bytes):
+    if isinstance(name, binary_type):
         name = name.decode('ascii')
     return scope.CurrentNameScope() + name
@@ -242,7 +243,7 @@ def _RectifyInputOutput(blobs, net=None):
     """A helper function to rectify the input or output of the CreateOperator
     interface.
     """
-    if isinstance(blobs, basestring):
+    if isinstance(blobs, string_types):
         # If blobs is a single string, prepend scope.CurrentNameScope()
         # and put it as a list.
         # TODO(jiayq): enforce using BlobReference instead of raw strings.
@@ -254,7 +255,7 @@ def _RectifyInputOutput(blobs, net=None):
         # If blob is a list, we go through it and type check.
         rectified = []
         for blob in blobs:
-            if isinstance(blob, basestring):
+            if isinstance(blob, string_types) or isinstance(blob, binary_type):
                 rectified.append(ScopedBlobReference(blob, net=net))
             elif type(blob) is BlobReference:
                 rectified.append(blob)
@@ -291,11 +292,11 @@ def CreateOperator(
     # Add rectified inputs and outputs
     inputs = _RectifyInputOutput(inputs)
     outputs = _RectifyInputOutput(outputs)
-    operator.input.extend([str(i) for i in inputs])
-    operator.output.extend([str(o) for o in outputs])
+    operator.input.extend([text_type(i) for i in inputs])
+    operator.output.extend([text_type(o) for o in outputs])
     if control_input:
         control_input = _RectifyInputOutput(control_input)
-        operator.control_input.extend([str(i) for i in control_input])
+        operator.control_input.extend([text_type(i) for i in control_input])
     # Set device option:
     # (1) If device_option is explicitly set, use device_option.
     # (2) If not, but scope.CurrentDeviceScope() is set,
@@ -926,9 +927,13 @@ class IR(object):
         all_input_to_grad_out = {}
         for key, val in all_input_to_grad.items():
             if val is not None:
-                all_input_to_grad_out[BlobReference(key)] = (
-                    BlobReference(val) if isinstance(val, basestring) else
-                    GradientSlice(BlobReference(val[0]), BlobReference(val[1])))
+                if (isinstance(val, string_types) or
+                        isinstance(val, binary_type)):
+                    grad_out = BlobReference(val)
+                else:
+                    grad_out = GradientSlice(BlobReference(val[0]),
+                                             BlobReference(val[1]))
+                all_input_to_grad_out[BlobReference(key)] = grad_out
         return all_gradient_ops, all_input_to_grad_out
@@ -1599,7 +1604,7 @@ class Net(object):
     def _ExtendOps(self, new_ops):
         self._net.op.extend(new_ops)
         for op in new_ops:
-            self._op_outputs.update([str(o) for o in op.output])
+            self._op_outputs.update([text_type(o) for o in op.output])

     def _CheckLookupTables(self):
         '''
@@ -1691,7 +1696,7 @@ class Net(object):
     def AddScopedExternalInputs(self, *inputs):
         res = self.AddExternalInput(
-            * [ScopedBlobReference(str(b)) for b in inputs]
+            * [ScopedBlobReference(b) for b in inputs]
         )
         if not isinstance(res, list):
             res = [res]
@@ -1699,7 +1704,7 @@ class Net(object):
     def AddScopedExternalOutputs(self, *outputs):
         return self.AddExternalOutput(
-            * [ScopedBlobReference(str(b)) for b in outputs]
+            * [ScopedBlobReference(b) for b in outputs]
         )

     @property
@@ -1826,9 +1831,9 @@ class Net(object):
         if len(op.output) == 0:
             return
         elif len(op.output) == 1:
-            return BlobReference(str(op.output[0]), self)
+            return BlobReference(op.output[0], self)
         else:
-            return tuple(BlobReference(str(o), self) for o in op.output)
+            return tuple(BlobReference(o, self) for o in op.output)

     def __getattr__(self, op_type):
         if op_type.startswith('__'):
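
The switch from str(...) to text_type(...) only matters on Py2, where plain str would produce bytes; on Py3 the two are the same type. A small sketch of the difference, assuming six is installed:

from six import PY2, text_type

name = u'blob'
assert isinstance(text_type(name), text_type)  # unicode on Py2, str on Py3
if PY2:
    assert isinstance(str(name), bytes)        # Py2 str is a bytes type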

View File

@@ -1114,7 +1114,7 @@ def _InferBlobDevice(model):
             step_args = [a for a in op.arg if a.name.endswith("step_net")]
             for step_arg in step_args:
                 step_proto = caffe2_pb2.NetDef()
-                protobuftx.Merge(step_arg.s, step_proto)
+                protobuftx.Merge(step_arg.s.decode("ascii"), step_proto)
                 map_ops(step_proto)
     map_ops(model.net.Proto())
     model._blob_to_device = mapping
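
The decode is needed because protobuf's text_format.Merge expects unicode text on Py3, while an Argument's `s` field holds bytes. A minimal round-trip sketch of the pattern this commit standardizes on:

from caffe2.proto import caffe2_pb2
import google.protobuf.text_format as protobuftx

step = caffe2_pb2.NetDef()
step.name = "step"
stored = str(step).encode("ascii")   # text-format proto stored as bytes

parsed = caffe2_pb2.NetDef()
protobuftx.Merge(stored.decode("ascii"), parsed)
assert parsed.name == "step"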

View File

@@ -453,7 +453,8 @@ class SparseDataParallelModelTest(TestCase):
         workspace.RunNet(model.net.Proto().name)

         if len(gpu_devices) == 2:
-            open("dump.txt", "w").write(str(model.net.Proto()))
+            with open("/tmp/dump.txt", "w") as f:
+                f.write(str(model.net.Proto()))
             if not cpu_indices:
                 idx = workspace.FetchBlob("gpu_0/indices")
                 idx = list(idx.flatten())

View File

@@ -812,10 +812,11 @@ class TestOperators(hu.HypothesisTestCase):
             N = prediction.shape[0]
             correct = 0
             for i in range(0, len(prediction)):
-                # we no longer have cmp function in python 3
-                pred_sorted = sorted([
-                    [item, j] for j, item in enumerate(prediction[i])],
-                    cmp=lambda x, y: int(y[0] > x[0]) - int(y[0] < x[0]))
+                pred_sorted = sorted(
+                    ([item, j] for j, item in enumerate(prediction[i])),
+                    key=lambda x: x[0],
+                    reverse=True
+                )
                 max_ids = [x[1] for x in pred_sorted[0:top_k]]
                 for m in max_ids:
                     if m == labels[i]:
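
For comparators that don't reduce to a simple key, functools.cmp_to_key is the drop-in Py3 port; here the comparator is just a descending sort on the score, so key/reverse is cleaner. A stand-alone check that both forms agree (toy data):

import functools

data = [[0.2, 0], [0.9, 1], [0.5, 2]]
legacy = sorted(data, key=functools.cmp_to_key(
    lambda x, y: int(y[0] > x[0]) - int(y[0] < x[0])))
modern = sorted(data, key=lambda x: x[0], reverse=True)
assert legacy == modern == [[0.9, 1], [0.5, 2], [0.2, 0]]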
@@ -889,7 +890,7 @@ class TestOperators(hu.HypothesisTestCase):
         def op_ref(lengths):
             sids = []
             for _, l in enumerate(lengths):
-                sids.extend(range(l))
+                sids.extend(list(range(l)))
             return (np.array(sids, dtype=np.int32), )

         self.assertReferenceChecks(
@@ -1122,7 +1123,11 @@ class TestOperators(hu.HypothesisTestCase):
         original matrices.
         """
         import threading
-        import Queue
+        try:
+            import queue
+        except ImportError:
+            # Py2 fallback: the module is named Queue there
+            import Queue as queue
         op = core.CreateOperator(
             "CreateBlobsQueue",
             [],
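
An equivalent, shorter shim is available through six.moves, assuming six is around:

from six.moves import queue

q = queue.Queue()
q.put(1)
assert q.get() == 1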
@@ -1134,7 +1139,7 @@ class TestOperators(hu.HypothesisTestCase):
         xs = [np.random.randn(num_elements, 5).astype(np.float32)
               for _ in range(num_blobs)]
-        q = Queue.Queue()
+        q = queue.Queue()
         for i in range(num_elements):
             q.put([x[i] for x in xs])
@@ -1152,7 +1157,7 @@ class TestOperators(hu.HypothesisTestCase):
                     self.ws.create_blob(feed_blob).feed(
                         elem, device_option=do)
                     self.ws.run(op)
-            except Queue.Empty:
+            except queue.Empty:
                 return

         # Create all blobs before racing on multiple threads
@@ -1840,7 +1845,7 @@ class TestOperators(hu.HypothesisTestCase):
             backward_link_internal=backward_link_internal,
             backward_link_external=backward_link_external,
             backward_link_offset=backward_link_offset,
-            param=map(inputs.index, step_net.params),
+            param=[inputs.index(p) for p in step_net.params],
             step_net=str(step_net.Proto()),
             backward_step_net=str(backward_step_net.Proto()),
             outputs_with_grads=[0],

View File

@@ -104,10 +104,13 @@ class LayersTestCase(test_util.TestCase):
     def assertArgsEqual(self, spec_args, op_args):
         self.assertEqual(len(spec_args), len(op_args))
+        keys = [a.name for a in op_args]

         def parse_args(args):
             operator = caffe2_pb2.OperatorDef()
-            for k, v in args.items():
+            # Generate the expected value in the same order
+            for k in keys:
+                v = args[k]
                 arg = utils.MakeArgument(k, v)
                 operator.arg.add().CopyFrom(arg)
             return operator.arg
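
The fix pins iteration order: the spec args come from a plain dict, whose iteration order is arbitrary on these Pythons, while op_args is an ordered repeated field, so the expected proto has to be built in the operator's own key order to compare element-wise. A toy illustration with hypothetical names:

op_arg_names = ['beta', 'alpha']      # order taken from the OperatorDef
spec = {'alpha': 1.0, 'beta': 2.0}    # plain dict, order not guaranteed
pairs = [(k, spec[k]) for k in op_arg_names]
assert pairs == [('beta', 2.0), ('alpha', 1.0)]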

View File

@@ -852,18 +852,18 @@ def apply_recurrent_blob_assignments(op, blob_assignments, canonical_name):
     step_args = [a for a in op.arg if a.name.endswith("step_net")]
     for step_arg in step_args:
         step_proto = caffe2_pb2.NetDef()
-        protobuftx.Merge(step_arg.s, step_proto)
+        protobuftx.Merge(step_arg.s.decode("ascii"), step_proto)
         apply_assignments(step_proto, blob_assignments)
         for i, einp in enumerate(step_proto.external_input):
             if einp in blob_assignments:
                 step_proto.external_input[i] = canonical_name(einp)
-        step_arg.s = str(step_proto)
+        step_arg.s = str(step_proto).encode("ascii")
     # Store renamings
     for blob, renamed in blob_assignments.items():
         if blob in list(op.input) + list(op.output):
             a = caffe2_pb2.Argument()
             a.name = blob + ".rename"
-            a.s = str(renamed)
+            a.s = str(renamed).encode("ascii")
             op.arg.extend([a])

View File

@@ -531,10 +531,10 @@ def ExtractPredictorNet(
             import google.protobuf.text_format as protobuftx
             for arg in op.arg:
                 if arg.name == 'backward_step_net':
-                    arg.s = str("")
+                    arg.s = b""
                 elif arg.name == 'step_net':
                     step_proto = caffe2_pb2.NetDef()
-                    protobuftx.Merge(arg.s, step_proto)
+                    protobuftx.Merge(arg.s.decode("ascii"), step_proto)
                     for step_op in step_proto.op:
                         if device is not None:
                             step_op.device_option.device_type = device.device_type
@@ -546,7 +546,7 @@ def ExtractPredictorNet(
                             orig_external_inputs
                         )
                     )
-                    arg.s = str(step_proto)
+                    arg.s = str(step_proto).encode("ascii")

         if device is not None:
             op.device_option.device_type = device.device_type

View File

@@ -20,7 +20,7 @@ import hypothesis.strategies as st

 def _assert_arrays_equal(actual, ref, err_msg):
-    if ref.dtype.kind in ('S', 'O'):
+    if ref.dtype.kind in ('S', 'O', 'U'):
         np.testing.assert_array_equal(actual, ref, err_msg=err_msg)
     else:
         np.testing.assert_allclose(

View File

@@ -119,7 +119,8 @@ class TestHsm(hu.HypothesisTestCase):
         for i in range(names.shape[0]):
             for j in range(names.shape[1]):
                 if names[i][j]:
-                    assert(names[i][j] == p_names[i][j])
+                    self.assertEquals(
+                        names[i][j], p_names[i][j].item().encode('utf-8'))
                 self.assertAlmostEqual(
                     scores[i][j], p_scores[i][j], delta=0.001)

View File

@@ -7,13 +7,17 @@ import unittest
 try:
     import cv2
 except ImportError:
-    raise unittest.SkipTest('python-opencv is not installed')
+    pass  # Handled below
 from PIL import Image
 import numpy as np
 import lmdb
 import shutil
-import StringIO
+try:
+    import StringIO
+except ImportError:
+    from io import StringIO
+import sys
 import tempfile

 # TODO: This test does not test scaling because
@@ -194,6 +198,7 @@ def create_test(output_dir, width, height, default_bound,
     return expected_results


+@unittest.skipIf('cv2' not in sys.modules, 'python-opencv is not installed')
 class TestImport(hu.HypothesisTestCase):
     @given(size_tuple=st.tuples(
         st.integers(min_value=8, max_value=4096),
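
Moving the guard from a module-level raise to a class decorator keeps the import free of side effects; under some runners a SkipTest raised at import time can abort collection of the whole file rather than skip it. The pattern in miniature:

import sys
import unittest

try:
    import cv2  # noqa
except ImportError:
    pass  # fall through; the decorator below handles it

@unittest.skipIf('cv2' not in sys.modules, 'python-opencv is not installed')
class ExampleTest(unittest.TestCase):
    def test_noop(self):
        self.assertTrue(True)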

View File

@@ -64,8 +64,14 @@ class TestIndexOps(TestCase):
             ['stored_entries']))
         stored_actual = workspace.FetchBlob('stored_entries')
         new_entries = np.array([entries[3], entries[4]], dtype=dtype)
-        np.testing.assert_array_equal(
-            np.concatenate((my_entries, new_entries)), stored_actual)
+        expected = np.concatenate((my_entries, new_entries))
+        if dtype is str:
+            # we'll always get bytes back from Caffe2
+            expected = np.array([
+                x.item().encode('utf-8') if isinstance(x, np.str_) else x
+                for x in expected
+            ], dtype=object)
+        np.testing.assert_array_equal(expected, stored_actual)

         workspace.RunOperatorOnce(core.CreateOperator(
             index_create_op,
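
The same bytes-vs-unicode mismatch shows up wherever string tensors are fetched: Caffe2 hands back bytes, so the expected side has to be encoded before comparing. A stand-alone illustration with made-up data:

import numpy as np

expected = np.array(['a', 'bc'])   # unicode under Py3
as_bytes = np.array([
    x.item().encode('utf-8') if isinstance(x, np.str_) else x
    for x in expected
], dtype=object)
assert as_bytes[0] == b'a' and as_bytes[1] == b'bc'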

View File

@@ -458,10 +458,9 @@ def prepare_mul_rnn(model, input_blob, shape, T, outputs_with_grad, num_layers):
         model=model,
         inputs=input_blob,
         initial_states=states,
-        outputs_with_grads=map(
-            lambda x: x + 2 * (num_layers - 1),
-            outputs_with_grad
-        ),
+        outputs_with_grads=[
+            x + 2 * (num_layers - 1) for x in outputs_with_grad
+        ],
         seq_lengths=None,
     )
     return results[-2:]
@@ -682,12 +681,12 @@ class RNNCellTest(hu.HypothesisTestCase):
             for arg in op.arg:
                 if arg.name == "step_net":
                     step_proto = caffe2_pb2.NetDef()
-                    protobuftx.Merge(arg.s, step_proto)
+                    protobuftx.Merge(arg.s.decode("ascii"), step_proto)
                     for step_op in step_proto.op:
                         self.assertEqual(0, step_op.device_option.device_type)
                         self.assertEqual(1, step_op.device_option.cuda_gpu_id)
                 elif arg.name == 'backward_step_net':
-                    self.assertEqual("", arg.s)
+                    self.assertEqual(b"", arg.s)

     @given(encoder_output_length=st.integers(1, 3),

View File

@@ -19,7 +19,10 @@ class TestCounterOps(TestCase):
         existing = len(previous_keys)

         prefix = '/'.join([__name__, 'TestCounterOps', 'test_stats_ops'])
-        keys = [prefix + '/key1', prefix + '/key2']
+        keys = [
+            (prefix + '/key1').encode('ascii'),
+            (prefix + '/key2').encode('ascii')
+        ]
         values = [34, 45]
         workspace.FeedBlob('k', np.array(keys, dtype=str))
         workspace.FeedBlob('v', np.array(values, dtype=np.int64))

View File

@@ -8,7 +8,6 @@ from caffe2.python.test_util import TestCase
 from caffe2.python.schema import Struct, Scalar, FetchRecord
 import tempfile
 import numpy as np
-import os


 class TestTextFileReader(TestCase):
@@ -24,14 +23,14 @@ class TestTextFileReader(TestCase):
             [0.456, 0.789, 0.10101, -24342.64],
         ]
         row_data = list(zip(*col_data))
-        txt_file = tempfile.NamedTemporaryFile(delete=False)
-        txt_file.write(
-            '\n'.join(
-                '\t'.join(str(x) for x in f)
-                for f in row_data
-            ) + '\n'
-        )
-        txt_file.close()
+        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as txt_file:
+            txt_file.write(
+                '\n'.join(
+                    '\t'.join(str(x) for x in f)
+                    for f in row_data
+                ) + '\n'
+            )
+            txt_file.flush()

         for num_passes in range(1, 3):
             for batch_size in range(1, len(row_data) + 2):
@@ -63,8 +62,6 @@ class TestTextFileReader(TestCase):
                 else:
                     np.testing.assert_array_equal(col_batch, results[i])

-        os.remove(txt_file.name)
-

 if __name__ == "__main__":
     import unittest
     unittest.main()
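
One trade-off worth noting: with delete=False the context manager closes but does not remove the file, and the old os.remove cleanup was dropped, so the temp file now outlives the test. If cleanup matters, the usual shape is:

import os
import tempfile

with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
    f.write('1\t2\n')
    path = f.name
# ... hand `path` to the reader under test ...
os.remove(path)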

View File

@@ -231,7 +231,7 @@ PythonOpBase::PythonOpBase(
   CAFFE_ENFORCE(pickle);
   auto loads = pickle.attr("loads").cast<py::object>();
   CAFFE_ENFORCE(loads);
-  auto builder_call = loads(pickled).cast<py::tuple>();
+  auto builder_call = loads(py::bytes(pickled)).cast<py::tuple>();
   CAFFE_ENFORCE(builder_call);
   CAFFE_ENFORCE_EQ(py::len(builder_call), 3);
   auto func = builder_call[0].cast<py::object>();
@@ -504,7 +504,8 @@ void addObjectMethods(py::module& m) {
           "_create_net",
           [](Workspace* self, py::bytes def, bool overwrite) -> py::object {
             caffe2::NetDef proto;
-            CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
+            CAFFE_ENFORCE(
+                ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
             auto* net = self->CreateNet(proto, overwrite);
             CAFFE_ENFORCE(net);
             return py::cast(net, py::return_value_policy::reference_internal);
@@ -527,7 +528,8 @@ void addObjectMethods(py::module& m) {
           "_run_net",
           [](Workspace* self, py::bytes def) {
             caffe2::NetDef proto;
-            CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
+            CAFFE_ENFORCE(
+                ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
             py::gil_scoped_release g;
             CAFFE_ENFORCE(self->RunNetOnce(proto));
           })
@@ -535,7 +537,8 @@ void addObjectMethods(py::module& m) {
           "_run_operator",
           [](Workspace* self, py::bytes def) {
             caffe2::OperatorDef proto;
-            CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
+            CAFFE_ENFORCE(
+                ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
             py::gil_scoped_release g;
             CAFFE_ENFORCE(self->RunOperatorOnce(proto));
           })
@@ -543,7 +546,8 @@ void addObjectMethods(py::module& m) {
           "_run_plan",
           [](Workspace* self, py::bytes def) {
             caffe2::PlanDef proto;
-            CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
+            CAFFE_ENFORCE(
+                ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
             py::gil_scoped_release g;
             CAFFE_ENFORCE(self->RunPlan(proto));
           })
@@ -568,7 +572,8 @@ void addObjectMethods(py::module& m) {
       "get_gradient_defs",
       [](py::bytes op_def, std::vector<GradientWrapper> output_gradients) {
         OperatorDef def;
-        CAFFE_ENFORCE(ParseProtobufFromLargeString(op_def, &def));
+        CAFFE_ENFORCE(
+            ParseProtobufFromLargeString(op_def.cast<std::string>(), &def));
         CAFFE_ENFORCE(caffe2::GradientRegistry()->Has(def.type()));
         const auto& meta = GetGradientForOp(def, output_gradients);
         std::vector<py::bytes> grad_ops;
@@ -639,9 +644,10 @@ void addObjectMethods(py::module& m) {
           [](Predictor& instance, py::bytes init_net, py::bytes predict_net) {
             CAFFE_ENFORCE(gWorkspace);
             NetDef init_net_, predict_net_;
-            CAFFE_ENFORCE(ParseProtobufFromLargeString(init_net, &init_net_));
-            CAFFE_ENFORCE(
-                ParseProtobufFromLargeString(predict_net, &predict_net_));
+            CAFFE_ENFORCE(ParseProtobufFromLargeString(
+                init_net.cast<std::string>(), &init_net_));
+            CAFFE_ENFORCE(ParseProtobufFromLargeString(
+                predict_net.cast<std::string>(), &predict_net_));
             new (&instance) Predictor(init_net_, predict_net_, gWorkspace);
           })
       .def(
@@ -781,15 +787,16 @@ void addGlobalMethods(py::module& m) {
   m.def(
       "create_net",
       [](py::bytes net_def, bool overwrite) {
+        CAFFE_ENFORCE(gWorkspace);
         caffe2::NetDef proto;
         CAFFE_ENFORCE(
-            ParseProtobufFromLargeString(net_def, &proto),
+            ParseProtobufFromLargeString(net_def.cast<std::string>(), &proto),
             "Can't parse net proto: ",
-            std::string(net_def));
+            net_def.cast<std::string>());
         CAFFE_ENFORCE(
             gWorkspace->CreateNet(proto, overwrite),
             "Error creating net with proto: ",
-            std::string(net_def));
+            net_def.cast<std::string>());
         return true;
       },
       py::arg("net_def"),
@@ -834,7 +841,8 @@ void addGlobalMethods(py::module& m) {
   m.def("run_operator_once", [](const py::bytes& op_def) {
     CAFFE_ENFORCE(gWorkspace);
     OperatorDef def;
-    CAFFE_ENFORCE(ParseProtobufFromLargeString(op_def, &def));
+    CAFFE_ENFORCE(
+        ParseProtobufFromLargeString(op_def.cast<std::string>(), &def));
     py::gil_scoped_release g;
     CAFFE_ENFORCE(gWorkspace->RunOperatorOnce(def));
     return true;
@@ -842,16 +850,17 @@ void addGlobalMethods(py::module& m) {
   m.def("run_net_once", [](const py::bytes& net_def) {
     CAFFE_ENFORCE(gWorkspace);
     NetDef def;
-    CAFFE_ENFORCE(ParseProtobufFromLargeString(net_def, &def));
+    CAFFE_ENFORCE(
+        ParseProtobufFromLargeString(net_def.cast<std::string>(), &def));
     py::gil_scoped_release g;
     CAFFE_ENFORCE(gWorkspace->RunNetOnce(def));
     return true;
   });
   m.def("run_plan", [](const py::bytes& plan_def) {
     CAFFE_ENFORCE(gWorkspace);
-    const std::string& msg = std::move(plan_def);
     PlanDef def;
-    CAFFE_ENFORCE(ParseProtobufFromLargeString(msg, &def));
+    CAFFE_ENFORCE(
+        ParseProtobufFromLargeString(plan_def.cast<std::string>(), &def));
     py::gil_scoped_release g;
     CAFFE_ENFORCE(gWorkspace->RunPlan(def));
     return true;
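
All of these bindings receive serialized protos, which on Py3 are bytes rather than str; the explicit .cast&lt;std::string&gt;() makes the bytes-to-string conversion explicit instead of leaning on implicit conversions that can involve text decoding. From the Python side the contract is simply (a minimal sketch, assuming the caffe2 protos are importable):

from caffe2.proto import caffe2_pb2

net = caffe2_pb2.NetDef()
net.name = "example"
payload = net.SerializeToString()
assert isinstance(payload, bytes)   # pass this across, never a decoded str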

View File

@@ -11,7 +11,7 @@ import sys
 import collections
 import functools
 import numpy as np
-from six import integer_types, string_types, text_type
+from six import integer_types, binary_type, text_type


 def CaffeBlobToNumpyArray(blob):
@@ -80,9 +80,10 @@ def MakeArgument(key, value):
         # We make a relaxation that a boolean variable will also be stored as
         # int.
         argument.i = value
-    elif isinstance(value, string_types):
-        argument.s = (value.encode('utf-8') if isinstance(value, text_type)
-                      else value)
+    elif isinstance(value, binary_type):
+        argument.s = value
+    elif isinstance(value, text_type):
+        argument.s = value.encode('utf-8')
     elif isinstance(value, Message):
         argument.s = value.SerializeToString()
     elif iterable and all(type(v) in [float, np.float_] for v in value):
@@ -95,7 +96,9 @@ def MakeArgument(key, value):
         argument.ints.extend(
             v.item() if type(v) is np.int_ else v for v in value
         )
-    elif iterable and all(isinstance(v, string_types) for v in value):
+    elif iterable and all(
+        isinstance(v, binary_type) or isinstance(v, text_type) for v in value
+    ):
         argument.strings.extend(
             v.encode('utf-8') if isinstance(v, text_type) else v
             for v in value
@@ -103,9 +106,18 @@ def MakeArgument(key, value):
     elif iterable and all(isinstance(v, Message) for v in value):
         argument.strings.extend(v.SerializeToString() for v in value)
     else:
-        raise ValueError(
-            "Unknown argument type: key=%s value=%s, value type=%s" %
-            (key, str(value), str(type(value)))
-        )
+        if iterable:
+            raise ValueError(
+                "Unknown iterable argument type: key={} value={}, value "
+                "type={}[{}]".format(
+                    key, value, type(value), set(type(v) for v in value)
+                )
+            )
+        else:
+            raise ValueError(
+                "Unknown argument type: key={} value={}, value type={}".format(
+                    key, value, type(value)
+                )
+            )
     return argument
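
A quick check of the new text/bytes branches (assuming a caffe2 build whose utils module is importable); both spellings should land as the same utf-8 bytes in the Argument:

from caffe2.python import utils

arg_text = utils.MakeArgument('name', u'blob')
arg_bytes = utils.MakeArgument('name', b'blob')
assert arg_text.s == arg_bytes.s == b'blob'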

View File

@@ -385,7 +385,6 @@ class TestCWorkspace(htu.HypothesisTestCase):
     @given(name=st.text(), value=st.floats(min_value=-1, max_value=1.0))
     def test_operator_run(self, name, value):
-        name = name.encode('ascii', 'ignore')
         ws = workspace.C.Workspace()
         op = core.CreateOperator(
             "ConstantFill", [], [name], shape=[1], value=value)
@@ -398,7 +397,6 @@ class TestCWorkspace(htu.HypothesisTestCase):
            net_name=st.text(),
            value=st.floats(min_value=-1, max_value=1.0))
     def test_net_run(self, blob_name, net_name, value):
-        blob_name = blob_name.encode('ascii', 'ignore')
         ws = workspace.C.Workspace()
         net = core.Net(net_name)
         net.ConstantFill([], [blob_name], shape=[1], value=value)
@@ -413,7 +411,6 @@ class TestCWorkspace(htu.HypothesisTestCase):
            plan_name=st.text(),
            value=st.floats(min_value=-1, max_value=1.0))
     def test_plan_run(self, blob_name, plan_name, net_name, value):
-        blob_name = blob_name.encode('ascii', 'ignore')
         ws = workspace.C.Workspace()
         plan = core.Plan(plan_name)
         net = core.Net(net_name)
@@ -431,7 +428,6 @@ class TestCWorkspace(htu.HypothesisTestCase):
            net_name=st.text(),
            value=st.floats(min_value=-1, max_value=1.0))
     def test_net_create(self, blob_name, net_name, value):
-        blob_name = blob_name.encode('ascii', 'ignore')
         ws = workspace.C.Workspace()
         net = core.Net(net_name)
         net.ConstantFill([], [blob_name], shape=[1], value=value)
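
Dropping the ascii round-trip means hypothesis-generated unicode names now flow straight into the workspace, exercising the BlobReference changes above. A minimal sketch of what these tests now feed through (assuming a built caffe2; the blob name is made up):

from caffe2.python import core

net = core.Net("test_net")
net.ConstantFill([], [u"ünïcode_blob"], shape=[1], value=0.5)
assert net.Proto().op[0].output[0] == u"ünïcode_blob"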