Core unit test fixes for Python 3

Summary: As title: str/bytes fixes, renamed stdlib modules, and lazy-iterator changes needed to get the core Python unit tests passing under Python 3.

Differential Revision: D5291327

fbshipit-source-id: 7dd9279c53ba55d3422c31973ffcec5705787fdf
Thomas Dudziak 2017-06-23 13:06:20 -07:00 committed by Facebook Github Bot
parent ff914bf201
commit 342de07231
18 changed files with 167 additions and 131 deletions

View File

@ -8,13 +8,7 @@ from __future__ import unicode_literals
import sys
import copy
import inspect
try:
from past.builtins import basestring
except ImportError:
print("You don't have the past package installed. ",
"This is necessary for python 2/3 compatibility. ",
"To do this, do 'pip install future'.")
sys.exit(1)
from past.builtins import basestring
from caffe2.python.model_helper import ModelHelper
# flake8: noqa
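This first hunk drops the guarded import and pulls basestring straight from past.builtins; the files below go further and switch to six's type tuples. A minimal sketch of the two compatibility spellings, assuming the future and six packages are installed:

    from past.builtins import basestring  # Python 2's basestring, on 2 and 3
    from six import string_types, binary_type, text_type

    assert isinstance(u'blob', basestring)
    assert isinstance(u'blob', string_types) and isinstance(u'blob', text_type)
    assert isinstance(b'blob', binary_type)
    # On Python 3 bytes do not match string_types, which is why the hunks
    # below add explicit binary_type branches next to the text checks.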

View File

@ -6,14 +6,8 @@ from __future__ import print_function
from __future__ import unicode_literals
from collections import namedtuple, OrderedDict
try:
from past.builtins import basestring
except ImportError:
print("You don't have the past package installed. ",
"This is necessary for python 2/3 compatibility. ",
"To do this, do 'pip install future'.")
import sys
sys.exit(1)
from past.builtins import basestring
from six import binary_type, string_types, text_type
from caffe2.proto import caffe2_pb2
from collections import defaultdict
@ -138,7 +132,12 @@ class BlobReference(object):
Note that this does not prepend the namescope. If needed, use
ScopedBlobReference() to prepend the existing namespace.
"""
self._name = name
if isinstance(name, string_types):
self._name = name
elif isinstance(name, binary_type):
self._name = name.decode('utf-8')
else:
self._name = str(name)
self._from_net = net
# meta allows helper functions to put whatever metainformation is
# needed there.
@ -148,8 +147,10 @@ class BlobReference(object):
return hash(self._name)
def __eq__(self, other):
if isinstance(other, basestring):
if isinstance(other, string_types):
return self._name == other
elif isinstance(other, binary_type):
return self._name == other.decode('utf-8')
elif isinstance(other, BlobReference):
return self._name == other._name
else:
@ -165,12 +166,12 @@ class BlobReference(object):
return 'BlobReference("{}")'.format(self._name)
def __add__(self, other):
if not isinstance(other, basestring):
if not isinstance(other, string_types):
raise RuntimeError('Cannot add BlobReference to a non-string.')
return BlobReference(self._name + other, self._from_net)
def __radd__(self, other):
if not isinstance(other, basestring):
if not isinstance(other, string_types):
raise RuntimeError('Cannot add a non-string to BlobReference.')
return BlobReference(other + self._name, self._from_net)
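Taken together, the BlobReference hunks normalize bytes names to text at construction time and make comparison and concatenation explicit about types. A sketch of the resulting behavior, assuming a caffe2 build is importable:

    from caffe2.python.core import BlobReference

    b = BlobReference(b'data')              # bytes are decoded as utf-8
    assert str(b) == 'data'
    assert b == 'data' and b == b'data'     # __eq__ accepts text and bytes
    assert str(b + '_grad') == 'data_grad'  # __add__ still requires text;
                                            # adding bytes raises on Python 3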
@ -185,7 +186,7 @@ class BlobReference(object):
network's __getattr__ function.
"""
inputs = [] if inputs is None else inputs
if isinstance(inputs, BlobReference) or isinstance(inputs, str):
if isinstance(inputs, BlobReference) or isinstance(inputs, string_types):
inputs = [inputs]
# add self to the input list.
inputs.insert(0, self)
@ -228,7 +229,7 @@ class BlobReference(object):
def ScopedName(name):
"""prefix the name with the current scope."""
if isinstance(name, bytes):
if isinstance(name, binary_type):
name = name.decode('ascii')
return scope.CurrentNameScope() + name
@ -242,7 +243,7 @@ def _RectifyInputOutput(blobs, net=None):
"""A helper function to rectify the input or output of the CreateOperator
interface.
"""
if isinstance(blobs, basestring):
if isinstance(blobs, string_types):
# If blobs is a single string, prepend scope.CurrentNameScope()
# and put it as a list.
# TODO(jiayq): enforce using BlobReference instead of raw strings.
@ -254,7 +255,7 @@ def _RectifyInputOutput(blobs, net=None):
# If blob is a list, we go through it and type check.
rectified = []
for blob in blobs:
if isinstance(blob, basestring):
if isinstance(blob, string_types) or isinstance(blob, binary_type):
rectified.append(ScopedBlobReference(blob, net=net))
elif type(blob) is BlobReference:
rectified.append(blob)
@ -291,11 +292,11 @@ def CreateOperator(
# Add rectified inputs and outputs
inputs = _RectifyInputOutput(inputs)
outputs = _RectifyInputOutput(outputs)
operator.input.extend([str(i) for i in inputs])
operator.output.extend([str(o) for o in outputs])
operator.input.extend([text_type(i) for i in inputs])
operator.output.extend([text_type(o) for o in outputs])
if control_input:
control_input = _RectifyInputOutput(control_input)
operator.control_input.extend([str(i) for i in control_input])
operator.control_input.extend([text_type(i) for i in control_input])
# Set device option:
# (1) If device_option is explicitly set, use device_option.
# (2) If not, but scope.CurrentDeviceScope() is set,
@ -926,9 +927,13 @@ class IR(object):
all_input_to_grad_out = {}
for key, val in all_input_to_grad.items():
if val is not None:
all_input_to_grad_out[BlobReference(key)] = (
BlobReference(val) if isinstance(val, basestring) else
GradientSlice(BlobReference(val[0]), BlobReference(val[1])))
if (isinstance(val, string_types) or
isinstance(val, binary_type)):
grad_out = BlobReference(val)
else:
grad_out = GradientSlice(BlobReference(val[0]),
BlobReference(val[1]))
all_input_to_grad_out[BlobReference(key)] = grad_out
return all_gradient_ops, all_input_to_grad_out
@ -1599,7 +1604,7 @@ class Net(object):
def _ExtendOps(self, new_ops):
self._net.op.extend(new_ops)
for op in new_ops:
self._op_outputs.update([str(o) for o in op.output])
self._op_outputs.update([text_type(o) for o in op.output])
def _CheckLookupTables(self):
'''
@ -1691,7 +1696,7 @@ class Net(object):
def AddScopedExternalInputs(self, *inputs):
res = self.AddExternalInput(
* [ScopedBlobReference(str(b)) for b in inputs]
* [ScopedBlobReference(b) for b in inputs]
)
if not isinstance(res, list):
res = [res]
@ -1699,7 +1704,7 @@ class Net(object):
def AddScopedExternalOutputs(self, *outputs):
return self.AddExternalOutput(
* [ScopedBlobReference(str(b)) for b in outputs]
* [ScopedBlobReference(b) for b in outputs]
)
@property
@ -1826,9 +1831,9 @@ class Net(object):
if len(op.output) == 0:
return
elif len(op.output) == 1:
return BlobReference(str(op.output[0]), self)
return BlobReference(op.output[0], self)
else:
return tuple(BlobReference(str(o), self) for o in op.output)
return tuple(BlobReference(o, self) for o in op.output)
def __getattr__(self, op_type):
if op_type.startswith('__'):

View File

@ -1114,7 +1114,7 @@ def _InferBlobDevice(model):
step_args = [a for a in op.arg if a.name.endswith("step_net")]
for step_arg in step_args:
step_proto = caffe2_pb2.NetDef()
protobuftx.Merge(step_arg.s, step_proto)
protobuftx.Merge(step_arg.s.decode("ascii"), step_proto)
map_ops(step_proto)
map_ops(model.net.Proto())
model._blob_to_device = mapping
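google.protobuf's text_format.Merge expects text on Python 3, while the step_net argument lives in a bytes field, hence the decode here and the matching encode where later hunks store a rewritten proto back. A round-trip sketch, assuming caffe2's protos are importable:

    from google.protobuf import text_format
    from caffe2.proto import caffe2_pb2

    net = caffe2_pb2.NetDef(name='step')
    raw = text_format.MessageToString(net).encode('ascii')  # str -> bytes field
    parsed = caffe2_pb2.NetDef()
    text_format.Merge(raw.decode('ascii'), parsed)          # bytes field -> str
    assert parsed.name == 'step'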

View File

@ -453,7 +453,8 @@ class SparseDataParallelModelTest(TestCase):
workspace.RunNet(model.net.Proto().name)
if len(gpu_devices) == 2:
open("dump.txt", "w").write(str(model.net.Proto()))
with open("/tmp/dump.txt", "w") as f:
f.write(str(model.net.Proto()))
if not cpu_indices:
idx = workspace.FetchBlob("gpu_0/indices")
idx = list(idx.flatten())

View File

@ -812,10 +812,11 @@ class TestOperators(hu.HypothesisTestCase):
N = prediction.shape[0]
correct = 0
for i in range(0, len(prediction)):
# we no longer have the cmp argument to sorted() in Python 3
pred_sorted = sorted([
[item, j] for j, item in enumerate(prediction[i])],
cmp=lambda x, y: int(y[0] > x[0]) - int(y[0] < x[0]))
pred_sorted = sorted(
([item, j] for j, item in enumerate(prediction[i])),
key=lambda x: x[0],
reverse=True
)
max_ids = [x[1] for x in pred_sorted[0:top_k]]
for m in max_ids:
if m == labels[i]:
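sorted() lost its cmp argument in Python 3. The old comparator here ordered by the first element, descending, which key= with reverse=True expresses directly; functools.cmp_to_key would reproduce the comparator verbatim. A standalone sketch of the equivalence:

    import functools

    pairs = [[0.2, 0], [0.9, 1], [0.5, 2]]
    by_key = sorted(pairs, key=lambda x: x[0], reverse=True)
    # the removed comparator, wrapped so it still works on Python 3
    by_cmp = sorted(pairs, key=functools.cmp_to_key(
        lambda x, y: int(y[0] > x[0]) - int(y[0] < x[0])))
    assert by_key == by_cmp == [[0.9, 1], [0.5, 2], [0.2, 0]]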
@ -889,7 +890,7 @@ class TestOperators(hu.HypothesisTestCase):
def op_ref(lengths):
sids = []
for _, l in enumerate(lengths):
sids.extend(range(l))
sids.extend(list(range(l)))
return (np.array(sids, dtype=np.int32), )
self.assertReferenceChecks(
@ -1122,7 +1123,11 @@ class TestOperators(hu.HypothesisTestCase):
original matrices.
"""
import threading
import Queue
try:
import queue
except ImportError:
# Py2 fallback: the stdlib module was named Queue
import Queue as queue
op = core.CreateOperator(
"CreateBlobsQueue",
[],
@ -1134,7 +1139,7 @@ class TestOperators(hu.HypothesisTestCase):
xs = [np.random.randn(num_elements, 5).astype(np.float32)
for _ in range(num_blobs)]
q = Queue.Queue()
q = queue.Queue()
for i in range(num_elements):
q.put([x[i] for x in xs])
@ -1152,7 +1157,7 @@ class TestOperators(hu.HypothesisTestCase):
self.ws.create_blob(feed_blob).feed(
elem, device_option=do)
self.ws.run(op)
except Queue.Empty:
except queue.Empty:
return
# Create all blobs before racing on multiple threads
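The hand-rolled try/except covers the Queue to queue rename; six.moves spells the same thing in one line, if six is already a dependency:

    from six.moves import queue  # Queue on Python 2, queue on Python 3

    q = queue.Queue()
    q.put(1)
    assert q.get() == 1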
@ -1840,7 +1845,7 @@ class TestOperators(hu.HypothesisTestCase):
backward_link_internal=backward_link_internal,
backward_link_external=backward_link_external,
backward_link_offset=backward_link_offset,
param=map(inputs.index, step_net.params),
param=[inputs.index(p) for p in step_net.params],
step_net=str(step_net.Proto()),
backward_step_net=str(backward_step_net.Proto()),
outputs_with_grads=[0],
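map() returns a one-shot lazy iterator on Python 3, and the param field of a protobuf message needs a concrete sequence, so the call becomes a list comprehension. The difference in miniature:

    inputs = ['x', 'w', 'b']
    params = ['w', 'b']
    param_idx = [inputs.index(p) for p in params]  # a real list on 2 and 3
    assert param_idx == [1, 2]
    # map(inputs.index, params) would be a lazy iterator on Python 3 and
    # would be exhausted after a single pass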

View File

@ -104,10 +104,13 @@ class LayersTestCase(test_util.TestCase):
def assertArgsEqual(self, spec_args, op_args):
self.assertEqual(len(spec_args), len(op_args))
keys = [a.name for a in op_args]
def parse_args(args):
operator = caffe2_pb2.OperatorDef()
for k, v in args.items():
# Generate the expected value in the same order
for k in keys:
v = args[k]
arg = utils.MakeArgument(k, v)
operator.arg.add().CopyFrom(arg)
return operator.arg
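Dict iteration order is arbitrary on the interpreters this code targets, so the expected arguments are now generated in the order the operator recorded them rather than in args.items() order. A toy illustration of the idea:

    keys = ['kernel', 'stride']            # order observed on the op
    args = {'stride': 2, 'kernel': 3}      # spec args, order unspecified
    ordered = [(k, args[k]) for k in keys]
    assert ordered == [('kernel', 3), ('stride', 2)]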

View File

@ -852,18 +852,18 @@ def apply_recurrent_blob_assignments(op, blob_assignments, canonical_name):
step_args = [a for a in op.arg if a.name.endswith("step_net")]
for step_arg in step_args:
step_proto = caffe2_pb2.NetDef()
protobuftx.Merge(step_arg.s, step_proto)
protobuftx.Merge(step_arg.s.decode("ascii"), step_proto)
apply_assignments(step_proto, blob_assignments)
for i, einp in enumerate(step_proto.external_input):
if einp in blob_assignments:
step_proto.external_input[i] = canonical_name(einp)
step_arg.s = str(step_proto)
step_arg.s = str(step_proto).encode("ascii")
# Store renamings
for blob, renamed in blob_assignments.items():
if blob in list(op.input) + list(op.output):
a = caffe2_pb2.Argument()
a.name = blob + ".rename"
a.s = str(renamed)
a.s = str(renamed).encode("ascii")
op.arg.extend([a])

View File

@ -531,10 +531,10 @@ def ExtractPredictorNet(
import google.protobuf.text_format as protobuftx
for arg in op.arg:
if arg.name == 'backward_step_net':
arg.s = str("")
arg.s = b""
elif arg.name == 'step_net':
step_proto = caffe2_pb2.NetDef()
protobuftx.Merge(arg.s, step_proto)
protobuftx.Merge(arg.s.decode("ascii"), step_proto)
for step_op in step_proto.op:
if device is not None:
step_op.device_option.device_type = device.device_type
@ -546,7 +546,7 @@ def ExtractPredictorNet(
orig_external_inputs
)
)
arg.s = str(step_proto)
arg.s = str(step_proto).encode("ascii")
if device is not None:
op.device_option.device_type = device.device_type

View File

@ -20,7 +20,7 @@ import hypothesis.strategies as st
def _assert_arrays_equal(actual, ref, err_msg):
if ref.dtype.kind in ('S', 'O'):
if ref.dtype.kind in ('S', 'O', 'U'):
np.testing.assert_array_equal(actual, ref, err_msg=err_msg)
else:
np.testing.assert_allclose(
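numpy's default string dtype under Python 3 is unicode, kind 'U', where Python 2 produced kind 'S' byte strings, so the exact-equality branch now accepts both. For instance:

    import numpy as np

    assert np.array(['a', 'b']).dtype.kind == 'U'    # text strings (Python 3)
    assert np.array([b'a', b'b']).dtype.kind == 'S'  # byte strings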

View File

@ -119,7 +119,8 @@ class TestHsm(hu.HypothesisTestCase):
for i in range(names.shape[0]):
for j in range(names.shape[1]):
if names[i][j]:
assert(names[i][j] == p_names[i][j])
self.assertEquals(
names[i][j], p_names[i][j].item().encode('utf-8'))
self.assertAlmostEqual(
scores[i][j], p_scores[i][j], delta=0.001)
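Caffe2 hands back raw bytes (see the index-ops test below), while the expected names are numpy text scalars on Python 3, so the assertion encodes before comparing. The conversion in isolation:

    import numpy as np

    name = np.array(['node'])[0]  # np.str_, i.e. text on Python 3
    assert name.item().encode('utf-8') == b'node'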

View File

@ -7,13 +7,17 @@ import unittest
try:
import cv2
except ImportError:
raise unittest.SkipTest('python-opencv is not installed')
pass # Handled below
from PIL import Image
import numpy as np
import lmdb
import shutil
import StringIO
try:
import StringIO
except ImportError:
from io import StringIO
import sys
import tempfile
# TODO: This test does not test scaling because
@ -194,6 +198,7 @@ def create_test(output_dir, width, height, default_bound,
return expected_results
@unittest.skipIf('cv2' not in sys.modules, 'python-opencv is not installed')
class TestImport(hu.HypothesisTestCase):
@given(size_tuple=st.tuples(
st.integers(min_value=8, max_value=4096),
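Instead of aborting the whole module at import time, a missing cv2 is now tolerated at import and the test class is skipped by decorator. The shape of the pattern, with a hypothetical test body:

    import sys
    import unittest

    try:
        import cv2  # noqa
    except ImportError:
        pass  # the decorator below skips everything that needs it

    @unittest.skipIf('cv2' not in sys.modules, 'python-opencv is not installed')
    class TestNeedsOpenCV(unittest.TestCase):
        def test_has_imread(self):
            self.assertTrue(hasattr(cv2, 'imread'))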

View File

@ -64,8 +64,14 @@ class TestIndexOps(TestCase):
['stored_entries']))
stored_actual = workspace.FetchBlob('stored_entries')
new_entries = np.array([entries[3], entries[4]], dtype=dtype)
np.testing.assert_array_equal(
np.concatenate((my_entries, new_entries)), stored_actual)
expected = np.concatenate((my_entries, new_entries))
if dtype is str:
# we'll always get bytes back from Caffe2
expected = np.array([
x.item().encode('utf-8') if isinstance(x, np.str_) else x
for x in expected
], dtype=object)
np.testing.assert_array_equal(expected, stored_actual)
workspace.RunOperatorOnce(core.CreateOperator(
index_create_op,

View File

@ -458,10 +458,9 @@ def prepare_mul_rnn(model, input_blob, shape, T, outputs_with_grad, num_layers):
model=model,
inputs=input_blob,
initial_states=states,
outputs_with_grads=map(
lambda x: x + 2 * (num_layers - 1),
outputs_with_grad
),
outputs_with_grads=[
x + 2 * (num_layers - 1) for x in outputs_with_grad
],
seq_lengths=None,
)
return results[-2:]
@ -682,12 +681,12 @@ class RNNCellTest(hu.HypothesisTestCase):
for arg in op.arg:
if arg.name == "step_net":
step_proto = caffe2_pb2.NetDef()
protobuftx.Merge(arg.s, step_proto)
protobuftx.Merge(arg.s.decode("ascii"), step_proto)
for step_op in step_proto.op:
self.assertEqual(0, step_op.device_option.device_type)
self.assertEqual(1, step_op.device_option.cuda_gpu_id)
elif arg.name == 'backward_step_net':
self.assertEqual("", arg.s)
self.assertEqual(b"", arg.s)
@given(encoder_output_length=st.integers(1, 3),

View File

@ -19,7 +19,10 @@ class TestCounterOps(TestCase):
existing = len(previous_keys)
prefix = '/'.join([__name__, 'TestCounterOps', 'test_stats_ops'])
keys = [prefix + '/key1', prefix + '/key2']
keys = [
(prefix + '/key1').encode('ascii'),
(prefix + '/key2').encode('ascii')
]
values = [34, 45]
workspace.FeedBlob('k', np.array(keys, dtype=str))
workspace.FeedBlob('v', np.array(values, dtype=np.int64))

View File

@ -8,7 +8,6 @@ from caffe2.python.test_util import TestCase
from caffe2.python.schema import Struct, Scalar, FetchRecord
import tempfile
import numpy as np
import os
class TestTextFileReader(TestCase):
@ -24,46 +23,44 @@ class TestTextFileReader(TestCase):
[0.456, 0.789, 0.10101, -24342.64],
]
row_data = list(zip(*col_data))
txt_file = tempfile.NamedTemporaryFile(delete=False)
txt_file.write(
'\n'.join(
'\t'.join(str(x) for x in f)
for f in row_data
) + '\n'
)
txt_file.close()
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as txt_file:
txt_file.write(
'\n'.join(
'\t'.join(str(x) for x in f)
for f in row_data
) + '\n'
)
txt_file.flush()
for num_passes in range(1, 3):
for batch_size in range(1, len(row_data) + 2):
init_net = core.Net('init_net')
reader = TextFileReader(
init_net,
filename=txt_file.name,
schema=schema,
batch_size=batch_size,
num_passes=num_passes)
workspace.RunNetOnce(init_net)
for num_passes in range(1, 3):
for batch_size in range(1, len(row_data) + 2):
init_net = core.Net('init_net')
reader = TextFileReader(
init_net,
filename=txt_file.name,
schema=schema,
batch_size=batch_size,
num_passes=num_passes)
workspace.RunNetOnce(init_net)
net = core.Net('read_net')
should_stop, record = reader.read_record(net)
net = core.Net('read_net')
should_stop, record = reader.read_record(net)
results = [np.array([])] * num_fields
while True:
workspace.RunNetOnce(net)
arrays = FetchRecord(record).field_blobs()
results = [np.array([])] * num_fields
while True:
workspace.RunNetOnce(net)
arrays = FetchRecord(record).field_blobs()
for i in range(num_fields):
results[i] = np.append(results[i], arrays[i])
if workspace.FetchBlob(should_stop):
break
for i in range(num_fields):
results[i] = np.append(results[i], arrays[i])
if workspace.FetchBlob(should_stop):
break
for i in range(num_fields):
col_batch = np.tile(col_data[i], num_passes)
if col_batch.dtype in (np.float32, np.float64):
np.testing.assert_array_almost_equal(
col_batch, results[i], decimal=3)
else:
np.testing.assert_array_equal(col_batch, results[i])
os.remove(txt_file.name)
col_batch = np.tile(col_data[i], num_passes)
if col_batch.dtype in (np.float32, np.float64):
np.testing.assert_array_almost_equal(
col_batch, results[i], decimal=3)
else:
np.testing.assert_array_equal(col_batch, results[i])
if __name__ == "__main__":
import unittest
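NamedTemporaryFile opens in binary mode by default, so writing str on Python 3 needs an explicit text mode, and the context manager replaces the manual close. A minimal sketch of the pattern (with the cleanup this test now skips):

    import os
    import tempfile

    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
        f.write('1\t0.25\n')  # str, not bytes
        f.flush()
        name = f.name
    with open(name) as f:
        assert f.read() == '1\t0.25\n'
    os.remove(name)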

View File

@ -231,7 +231,7 @@ PythonOpBase::PythonOpBase(
CAFFE_ENFORCE(pickle);
auto loads = pickle.attr("loads").cast<py::object>();
CAFFE_ENFORCE(loads);
auto builder_call = loads(pickled).cast<py::tuple>();
auto builder_call = loads(py::bytes(pickled)).cast<py::tuple>();
CAFFE_ENFORCE(builder_call);
CAFFE_ENFORCE_EQ(py::len(builder_call), 3);
auto func = builder_call[0].cast<py::object>();
@ -504,7 +504,8 @@ void addObjectMethods(py::module& m) {
"_create_net",
[](Workspace* self, py::bytes def, bool overwrite) -> py::object {
caffe2::NetDef proto;
CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
CAFFE_ENFORCE(
ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
auto* net = self->CreateNet(proto, overwrite);
CAFFE_ENFORCE(net);
return py::cast(net, py::return_value_policy::reference_internal);
@ -527,7 +528,8 @@ void addObjectMethods(py::module& m) {
"_run_net",
[](Workspace* self, py::bytes def) {
caffe2::NetDef proto;
CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
CAFFE_ENFORCE(
ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
py::gil_scoped_release g;
CAFFE_ENFORCE(self->RunNetOnce(proto));
})
@ -535,7 +537,8 @@ void addObjectMethods(py::module& m) {
"_run_operator",
[](Workspace* self, py::bytes def) {
caffe2::OperatorDef proto;
CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
CAFFE_ENFORCE(
ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
py::gil_scoped_release g;
CAFFE_ENFORCE(self->RunOperatorOnce(proto));
})
@ -543,7 +546,8 @@ void addObjectMethods(py::module& m) {
"_run_plan",
[](Workspace* self, py::bytes def) {
caffe2::PlanDef proto;
CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
CAFFE_ENFORCE(
ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
py::gil_scoped_release g;
CAFFE_ENFORCE(self->RunPlan(proto));
})
@ -568,7 +572,8 @@ void addObjectMethods(py::module& m) {
"get_gradient_defs",
[](py::bytes op_def, std::vector<GradientWrapper> output_gradients) {
OperatorDef def;
CAFFE_ENFORCE(ParseProtobufFromLargeString(op_def, &def));
CAFFE_ENFORCE(
ParseProtobufFromLargeString(op_def.cast<std::string>(), &def));
CAFFE_ENFORCE(caffe2::GradientRegistry()->Has(def.type()));
const auto& meta = GetGradientForOp(def, output_gradients);
std::vector<py::bytes> grad_ops;
@ -639,9 +644,10 @@ void addObjectMethods(py::module& m) {
[](Predictor& instance, py::bytes init_net, py::bytes predict_net) {
CAFFE_ENFORCE(gWorkspace);
NetDef init_net_, predict_net_;
CAFFE_ENFORCE(ParseProtobufFromLargeString(init_net, &init_net_));
CAFFE_ENFORCE(
ParseProtobufFromLargeString(predict_net, &predict_net_));
CAFFE_ENFORCE(ParseProtobufFromLargeString(
init_net.cast<std::string>(), &init_net_));
CAFFE_ENFORCE(ParseProtobufFromLargeString(
predict_net.cast<std::string>(), &predict_net_));
new (&instance) Predictor(init_net_, predict_net_, gWorkspace);
})
.def(
@ -781,15 +787,16 @@ void addGlobalMethods(py::module& m) {
m.def(
"create_net",
[](py::bytes net_def, bool overwrite) {
CAFFE_ENFORCE(gWorkspace);
caffe2::NetDef proto;
CAFFE_ENFORCE(
ParseProtobufFromLargeString(net_def, &proto),
ParseProtobufFromLargeString(net_def.cast<std::string>(), &proto),
"Can't parse net proto: ",
std::string(net_def));
net_def.cast<std::string>());
CAFFE_ENFORCE(
gWorkspace->CreateNet(proto, overwrite),
"Error creating net with proto: ",
std::string(net_def));
net_def.cast<std::string>());
return true;
},
py::arg("net_def"),
@ -834,7 +841,8 @@ void addGlobalMethods(py::module& m) {
m.def("run_operator_once", [](const py::bytes& op_def) {
CAFFE_ENFORCE(gWorkspace);
OperatorDef def;
CAFFE_ENFORCE(ParseProtobufFromLargeString(op_def, &def));
CAFFE_ENFORCE(
ParseProtobufFromLargeString(op_def.cast<std::string>(), &def));
py::gil_scoped_release g;
CAFFE_ENFORCE(gWorkspace->RunOperatorOnce(def));
return true;
@ -842,16 +850,17 @@ void addGlobalMethods(py::module& m) {
m.def("run_net_once", [](const py::bytes& net_def) {
CAFFE_ENFORCE(gWorkspace);
NetDef def;
CAFFE_ENFORCE(ParseProtobufFromLargeString(net_def, &def));
CAFFE_ENFORCE(
ParseProtobufFromLargeString(net_def.cast<std::string>(), &def));
py::gil_scoped_release g;
CAFFE_ENFORCE(gWorkspace->RunNetOnce(def));
return true;
});
m.def("run_plan", [](const py::bytes& plan_def) {
CAFFE_ENFORCE(gWorkspace);
const std::string& msg = std::move(plan_def);
PlanDef def;
CAFFE_ENFORCE(ParseProtobufFromLargeString(msg, &def));
CAFFE_ENFORCE(
ParseProtobufFromLargeString(plan_def.cast<std::string>(), &def));
py::gil_scoped_release g;
CAFFE_ENFORCE(gWorkspace->RunPlan(def));
return true;
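The C++ hunks above make every binding cast its py::bytes argument to std::string explicitly before ParseProtobufFromLargeString, and wrap the pickled payload in py::bytes before unpickling. The Python call sites are unchanged; a quick smoke test of one of them, assuming a caffe2 build:

    from caffe2.python import core, workspace

    op = core.CreateOperator('ConstantFill', [], ['x'], shape=[1], value=1.0)
    # the serialized OperatorDef crosses into C++ as py::bytes
    workspace.RunOperatorOnce(op)
    assert workspace.FetchBlob('x')[0] == 1.0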

View File

@ -11,7 +11,7 @@ import sys
import collections
import functools
import numpy as np
from six import integer_types, string_types, text_type
from six import integer_types, binary_type, text_type
def CaffeBlobToNumpyArray(blob):
@ -80,9 +80,10 @@ def MakeArgument(key, value):
# We make a relaxation that a boolean variable will also be stored as
# int.
argument.i = value
elif isinstance(value, string_types):
argument.s = (value.encode('utf-8') if isinstance(value, text_type)
else value)
elif isinstance(value, binary_type):
argument.s = value
elif isinstance(value, text_type):
argument.s = value.encode('utf-8')
elif isinstance(value, Message):
argument.s = value.SerializeToString()
elif iterable and all(type(v) in [float, np.float_] for v in value):
@ -95,7 +96,9 @@ def MakeArgument(key, value):
argument.ints.extend(
v.item() if type(v) is np.int_ else v for v in value
)
elif iterable and all(isinstance(v, string_types) for v in value):
elif iterable and all(
isinstance(v, binary_type) or isinstance(v, text_type) for v in value
):
argument.strings.extend(
v.encode('utf-8') if isinstance(v, text_type) else v
for v in value
@ -103,10 +106,19 @@ def MakeArgument(key, value):
elif iterable and all(isinstance(v, Message) for v in value):
argument.strings.extend(v.SerializeToString() for v in value)
else:
raise ValueError(
"Unknown argument type: key=%s value=%s, value type=%s" %
(key, str(value), str(type(value)))
)
if iterable:
raise ValueError(
"Unknown iterable argument type: key={} value={}, value "
"type={}[{}]".format(
key, value, type(value), set(type(v) for v in value)
)
)
else:
raise ValueError(
"Unknown argument type: key={} value={}, value type={}".format(
key, value, type(value)
)
)
return argument
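With these hunks, MakeArgument routes text and bytes separately: text is utf-8 encoded into the proto's bytes field and bytes pass through untouched, for scalars and iterables alike. Expected behavior, assuming a caffe2 build:

    from caffe2.python import utils

    assert utils.MakeArgument('order', u'NCHW').s == b'NCHW'
    assert utils.MakeArgument('order', b'NCHW').s == b'NCHW'
    assert list(utils.MakeArgument('names', [u'a', b'b']).strings) == [b'a', b'b']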

View File

@ -385,7 +385,6 @@ class TestCWorkspace(htu.HypothesisTestCase):
@given(name=st.text(), value=st.floats(min_value=-1, max_value=1.0))
def test_operator_run(self, name, value):
name = name.encode('ascii', 'ignore')
ws = workspace.C.Workspace()
op = core.CreateOperator(
"ConstantFill", [], [name], shape=[1], value=value)
@ -398,7 +397,6 @@ class TestCWorkspace(htu.HypothesisTestCase):
net_name=st.text(),
value=st.floats(min_value=-1, max_value=1.0))
def test_net_run(self, blob_name, net_name, value):
blob_name = blob_name.encode('ascii', 'ignore')
ws = workspace.C.Workspace()
net = core.Net(net_name)
net.ConstantFill([], [blob_name], shape=[1], value=value)
@ -413,7 +411,6 @@ class TestCWorkspace(htu.HypothesisTestCase):
plan_name=st.text(),
value=st.floats(min_value=-1, max_value=1.0))
def test_plan_run(self, blob_name, plan_name, net_name, value):
blob_name = blob_name.encode('ascii', 'ignore')
ws = workspace.C.Workspace()
plan = core.Plan(plan_name)
net = core.Net(net_name)
@ -431,7 +428,6 @@ class TestCWorkspace(htu.HypothesisTestCase):
net_name=st.text(),
value=st.floats(min_value=-1, max_value=1.0))
def test_net_create(self, blob_name, net_name, value):
blob_name = blob_name.encode('ascii', 'ignore')
ws = workspace.C.Workspace()
net = core.Net(net_name)
net.ConstantFill([], [blob_name], shape=[1], value=value)