From 342de07231dc8beed1cf145eb32ddfa305238b54 Mon Sep 17 00:00:00 2001
From: Thomas Dudziak
Date: Fri, 23 Jun 2017 13:06:20 -0700
Subject: [PATCH] Core unit test fixes for Python 3

Summary: As title

Differential Revision: D5291327

fbshipit-source-id: 7dd9279c53ba55d3422c31973ffcec5705787fdf
---
 caffe2/python/brew.py                         |  8 +--
 caffe2/python/core.py                         | 59 ++++++++-------
 caffe2/python/data_parallel_model.py          |  2 +-
 caffe2/python/data_parallel_model_test.py     |  3 +-
 caffe2/python/hypothesis_test.py              | 23 +++---
 caffe2/python/layer_test_util.py              |  5 +-
 caffe2/python/memonger.py                     |  6 +-
 caffe2/python/model_helper.py                 |  6 +-
 .../python/operator_test/dataset_ops_test.py  |  2 +-
 caffe2/python/operator_test/hsm_test.py       |  3 +-
 .../operator_test/image_input_op_test.py      |  9 ++-
 caffe2/python/operator_test/index_ops_test.py | 10 ++-
 caffe2/python/operator_test/rnn_cell_test.py  | 11 ++-
 caffe2/python/operator_test/stats_ops_test.py |  5 +-
 .../operator_test/text_file_reader_test.py    | 71 +++++++++----------
 caffe2/python/pybind_state.cc                 | 41 ++++++-----
 caffe2/python/utils.py                        | 30 +++++---
 caffe2/python/workspace_test.py               |  4 --
 18 files changed, 167 insertions(+), 131 deletions(-)

diff --git a/caffe2/python/brew.py b/caffe2/python/brew.py
index e3991da9b69..87413b89804 100644
--- a/caffe2/python/brew.py
+++ b/caffe2/python/brew.py
@@ -8,13 +8,7 @@ from __future__ import unicode_literals
 import sys
 import copy
 import inspect
-try:
-    from past.builtins import basestring
-except ImportError:
-    print("You don't have the past package installed. ",
-          "This is necessary for python 2/3 compatibility. ",
-          "To do this, do 'pip install future'.")
-    sys.exit(1)
+from past.builtins import basestring
 from caffe2.python.model_helper import ModelHelper

 # flake8: noqa
diff --git a/caffe2/python/core.py b/caffe2/python/core.py
index bb5634c2055..631e62df302 100644
--- a/caffe2/python/core.py
+++ b/caffe2/python/core.py
@@ -6,14 +6,8 @@ from __future__ import print_function
 from __future__ import unicode_literals

 from collections import namedtuple, OrderedDict
-try:
-    from past.builtins import basestring
-except ImportError:
-    print("You don't have the past package installed. ",
-          "This is necessary for python 2/3 compatibility. ",
-          "To do this, do 'pip install future'.")
-    import sys
-    sys.exit(1)
+from past.builtins import basestring
+from six import binary_type, string_types, text_type

 from caffe2.proto import caffe2_pb2
 from collections import defaultdict
@@ -138,7 +132,12 @@ class BlobReference(object):
         Note that this does not prepends the namescope. If needed, use
         ScopedBlobReference() to prepend the existing namespace.
         """
-        self._name = name
+        if isinstance(name, string_types):
+            self._name = name
+        elif isinstance(name, binary_type):
+            self._name = name.decode('utf-8')
+        else:
+            self._name = str(name)
         self._from_net = net
         # meta allows helper functions to put whatever metainformation needed
         # there.
@@ -148,8 +147,10 @@ class BlobReference(object):
         return hash(self._name)

     def __eq__(self, other):
-        if isinstance(other, basestring):
+        if isinstance(other, string_types):
             return self._name == other
+        elif isinstance(other, binary_type):
+            return self._name == other.decode('utf-8')
         elif isinstance(other, BlobReference):
             return self._name == other._name
         else:
@@ -165,12 +166,12 @@ class BlobReference(object):
         return 'BlobReference("{}")'.format(self._name)

     def __add__(self, other):
-        if not isinstance(other, basestring):
+        if not isinstance(other, string_types):
             raise RuntimeError('Cannot add BlobReference to a non-string.')
         return BlobReference(self._name + other, self._from_net)

     def __radd__(self, other):
-        if not isinstance(other, basestring):
+        if not isinstance(other, string_types):
             raise RuntimeError('Cannot add a non-string to BlobReference.')
         return BlobReference(other + self._name, self._from_net)

@@ -185,7 +186,7 @@ class BlobReference(object):
         network's __getattr__ function.
         """
         inputs = [] if inputs is None else inputs
-        if isinstance(inputs, BlobReference) or isinstance(inputs, str):
+        if isinstance(inputs, BlobReference) or isinstance(inputs, string_types):
             inputs = [inputs]
         # add self to the input list.
         inputs.insert(0, self)
@@ -228,7 +229,7 @@ class BlobReference(object):

 def ScopedName(name):
     """prefix the name with the current scope."""
-    if isinstance(name, bytes):
+    if isinstance(name, binary_type):
         name = name.decode('ascii')
     return scope.CurrentNameScope() + name

@@ -242,7 +243,7 @@ def _RectifyInputOutput(blobs, net=None):
     """A helper function to rectify the input or output of the CreateOperator
     interface.
     """
-    if isinstance(blobs, basestring):
+    if isinstance(blobs, string_types):
         # If blobs is a single string, prepend scope.CurrentNameScope()
         # and put it as a list.
         # TODO(jiayq): enforce using BlobReference instead of raw strings.
@@ -254,7 +255,7 @@ def _RectifyInputOutput(blobs, net=None):
         # If blob is a list, we go through it and type check.
         rectified = []
         for blob in blobs:
-            if isinstance(blob, basestring):
+            if isinstance(blob, string_types) or isinstance(blob, binary_type):
                 rectified.append(ScopedBlobReference(blob, net=net))
             elif type(blob) is BlobReference:
                 rectified.append(blob)
@@ -291,11 +292,11 @@ def CreateOperator(
     # Add rectified inputs and outputs
     inputs = _RectifyInputOutput(inputs)
     outputs = _RectifyInputOutput(outputs)
-    operator.input.extend([str(i) for i in inputs])
-    operator.output.extend([str(o) for o in outputs])
+    operator.input.extend([text_type(i) for i in inputs])
+    operator.output.extend([text_type(o) for o in outputs])
     if control_input:
         control_input = _RectifyInputOutput(control_input)
-        operator.control_input.extend([str(i) for i in control_input])
+        operator.control_input.extend([text_type(i) for i in control_input])
     # Set device option:
     # (1) If device_option is explicitly set, use device_option.
     # (2) If not, but scope.CurrentDeviceScope() is set,
@@ -926,9 +927,13 @@ class IR(object):
         all_input_to_grad_out = {}
         for key, val in all_input_to_grad.items():
             if val is not None:
-                all_input_to_grad_out[BlobReference(key)] = (
-                    BlobReference(val) if isinstance(val, basestring) else
-                    GradientSlice(BlobReference(val[0]), BlobReference(val[1])))
+                if (isinstance(val, string_types) or
+                        isinstance(val, binary_type)):
+                    grad_out = BlobReference(val)
+                else:
+                    grad_out = GradientSlice(BlobReference(val[0]),
+                                             BlobReference(val[1]))
+                all_input_to_grad_out[BlobReference(key)] = grad_out
         return all_gradient_ops, all_input_to_grad_out

@@ -1599,7 +1604,7 @@ class Net(object):
     def _ExtendOps(self, new_ops):
         self._net.op.extend(new_ops)
         for op in new_ops:
-            self._op_outputs.update([str(o) for o in op.output])
+            self._op_outputs.update([text_type(o) for o in op.output])

     def _CheckLookupTables(self):
         '''
@@ -1691,7 +1696,7 @@ class Net(object):

     def AddScopedExternalInputs(self, *inputs):
         res = self.AddExternalInput(
-            * [ScopedBlobReference(str(b)) for b in inputs]
+            * [ScopedBlobReference(b) for b in inputs]
         )
         if not isinstance(res, list):
             res = [res]
@@ -1699,7 +1704,7 @@ class Net(object):

     def AddScopedExternalOutputs(self, *outputs):
         return self.AddExternalOutput(
-            * [ScopedBlobReference(str(b)) for b in outputs]
+            * [ScopedBlobReference(b) for b in outputs]
         )

     @property
@@ -1826,9 +1831,9 @@ class Net(object):
         if len(op.output) == 0:
             return
         elif len(op.output) == 1:
-            return BlobReference(str(op.output[0]), self)
+            return BlobReference(op.output[0], self)
         else:
-            return tuple(BlobReference(str(o), self) for o in op.output)
+            return tuple(BlobReference(o, self) for o in op.output)

     def __getattr__(self, op_type):
         if op_type.startswith('__'):
diff --git a/caffe2/python/data_parallel_model.py b/caffe2/python/data_parallel_model.py
index ba7c64f2d79..cd324ecb801 100644
--- a/caffe2/python/data_parallel_model.py
+++ b/caffe2/python/data_parallel_model.py
@@ -1114,7 +1114,7 @@ def _InferBlobDevice(model):
             step_args = [a for a in op.arg if a.name.endswith("step_net")]
             for step_arg in step_args:
                 step_proto = caffe2_pb2.NetDef()
-                protobuftx.Merge(step_arg.s, step_proto)
+                protobuftx.Merge(step_arg.s.decode("ascii"), step_proto)
                 map_ops(step_proto)
     map_ops(model.net.Proto())
     model._blob_to_device = mapping
diff --git a/caffe2/python/data_parallel_model_test.py b/caffe2/python/data_parallel_model_test.py
index 1aeb6473e63..edb2ad9de83 100644
--- a/caffe2/python/data_parallel_model_test.py
+++ b/caffe2/python/data_parallel_model_test.py
@@ -453,7 +453,8 @@ class SparseDataParallelModelTest(TestCase):
             workspace.RunNet(model.net.Proto().name)

         if len(gpu_devices) == 2:
-            open("dump.txt", "w").write(str(model.net.Proto()))
+            with open("/tmp/dump.txt", "w") as f:
+                f.write(str(model.net.Proto()))
             if not cpu_indices:
                 idx = workspace.FetchBlob("gpu_0/indices")
                 idx = list(idx.flatten())
diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py
index 8fcc81a96eb..18712e730f7 100644
--- a/caffe2/python/hypothesis_test.py
+++ b/caffe2/python/hypothesis_test.py
@@ -812,10 +812,11 @@ class TestOperators(hu.HypothesisTestCase):
             N = prediction.shape[0]
             correct = 0
             for i in range(0, len(prediction)):
-                # we no longer have cmp function in python 3
-                pred_sorted = sorted([
-                    [item, j] for j, item in enumerate(prediction[i])],
-                    cmp=lambda x, y: int(y[0] > x[0]) - int(y[0] < x[0]))
+                pred_sorted = sorted(
+                    ([item, j] for j, item in enumerate(prediction[i])),
+                    key=lambda x: x[0],
+                    reverse=True
+                )
                 max_ids = [x[1] for x in pred_sorted[0:top_k]]
                 for m in max_ids:
                     if m == labels[i]:
@@ -889,7 +890,7 @@ class TestOperators(hu.HypothesisTestCase):
         def op_ref(lengths):
             sids = []
             for _, l in enumerate(lengths):
-                sids.extend(range(l))
+                sids.extend(list(range(l)))
             return (np.array(sids, dtype=np.int32), )

         self.assertReferenceChecks(
@@ -1122,7 +1123,11 @@ class TestOperators(hu.HypothesisTestCase):
         original matrices.
         """
         import threading
-        import Queue
+        try:
+            import queue
+        except ImportError:
+            # Py2 (the module was renamed to "queue" in Python 3)
+            import Queue as queue
         op = core.CreateOperator(
             "CreateBlobsQueue",
             [],
@@ -1134,7 +1139,7 @@ class TestOperators(hu.HypothesisTestCase):

         xs = [np.random.randn(num_elements, 5).astype(np.float32)
               for _ in range(num_blobs)]
-        q = Queue.Queue()
+        q = queue.Queue()
         for i in range(num_elements):
             q.put([x[i] for x in xs])
@@ -1152,7 +1157,7 @@ class TestOperators(hu.HypothesisTestCase):
                     self.ws.create_blob(feed_blob).feed(
                         elem, device_option=do)
                     self.ws.run(op)
-            except Queue.Empty:
+            except queue.Empty:
                 return

         # Create all blobs before racing on multiple threads
@@ -1840,7 +1845,7 @@ class TestOperators(hu.HypothesisTestCase):
             backward_link_internal=backward_link_internal,
             backward_link_external=backward_link_external,
             backward_link_offset=backward_link_offset,
-            param=map(inputs.index, step_net.params),
+            param=[inputs.index(p) for p in step_net.params],
             step_net=str(step_net.Proto()),
             backward_step_net=str(backward_step_net.Proto()),
             outputs_with_grads=[0],
diff --git a/caffe2/python/layer_test_util.py b/caffe2/python/layer_test_util.py
index 3086ca79e22..d8b351eaa0f 100644
--- a/caffe2/python/layer_test_util.py
+++ b/caffe2/python/layer_test_util.py
@@ -104,10 +104,13 @@ class LayersTestCase(test_util.TestCase):

     def assertArgsEqual(self, spec_args, op_args):
         self.assertEqual(len(spec_args), len(op_args))
+        keys = [a.name for a in op_args]

         def parse_args(args):
             operator = caffe2_pb2.OperatorDef()
-            for k, v in args.items():
+            # Generate the expected value in the same order
+            for k in keys:
+                v = args[k]
                 arg = utils.MakeArgument(k, v)
                 operator.arg.add().CopyFrom(arg)
             return operator.arg
diff --git a/caffe2/python/memonger.py b/caffe2/python/memonger.py
index d3630689e49..1029214338c 100644
--- a/caffe2/python/memonger.py
+++ b/caffe2/python/memonger.py
@@ -852,18 +852,18 @@ def apply_recurrent_blob_assignments(op, blob_assignments, canonical_name):
     step_args = [a for a in op.arg if a.name.endswith("step_net")]
     for step_arg in step_args:
         step_proto = caffe2_pb2.NetDef()
-        protobuftx.Merge(step_arg.s, step_proto)
+        protobuftx.Merge(step_arg.s.decode("ascii"), step_proto)
         apply_assignments(step_proto, blob_assignments)
         for i, einp in enumerate(step_proto.external_input):
             if einp in blob_assignments:
                 step_proto.external_input[i] = canonical_name(einp)
-        step_arg.s = str(step_proto)
+        step_arg.s = str(step_proto).encode("ascii")
     # Store renamings
     for blob, renamed in blob_assignments.items():
         if blob in list(op.input) + list(op.output):
             a = caffe2_pb2.Argument()
             a.name = blob + ".rename"
-            a.s = str(renamed)
+            a.s = str(renamed).encode("ascii")
             op.arg.extend([a])
diff --git a/caffe2/python/model_helper.py b/caffe2/python/model_helper.py
index 21fa73c44f3..bbdd13bf212 100644
--- a/caffe2/python/model_helper.py
+++ b/caffe2/python/model_helper.py
@@ -531,10 +531,10 @@ def ExtractPredictorNet(
         import google.protobuf.text_format as protobuftx
         for arg in op.arg:
             if arg.name == 'backward_step_net':
-                arg.s = str("")
+                arg.s = b""
             elif arg.name == 'step_net':
                 step_proto = caffe2_pb2.NetDef()
-                protobuftx.Merge(arg.s, step_proto)
+                protobuftx.Merge(arg.s.decode("ascii"), step_proto)
                 for step_op in step_proto.op:
                     if device is not None:
                         step_op.device_option.device_type = device.device_type
@@ -546,7 +546,7 @@ def ExtractPredictorNet(
                         orig_external_inputs
                     )
                 )
-                arg.s = str(step_proto)
+                arg.s = str(step_proto).encode("ascii")

         if device is not None:
             op.device_option.device_type = device.device_type
diff --git a/caffe2/python/operator_test/dataset_ops_test.py b/caffe2/python/operator_test/dataset_ops_test.py
index db83c58d6c7..ab6645e250b 100644
--- a/caffe2/python/operator_test/dataset_ops_test.py
+++ b/caffe2/python/operator_test/dataset_ops_test.py
@@ -20,7 +20,7 @@ import hypothesis.strategies as st


 def _assert_arrays_equal(actual, ref, err_msg):
-    if ref.dtype.kind in ('S', 'O'):
+    if ref.dtype.kind in ('S', 'O', 'U'):
         np.testing.assert_array_equal(actual, ref, err_msg=err_msg)
     else:
         np.testing.assert_allclose(
diff --git a/caffe2/python/operator_test/hsm_test.py b/caffe2/python/operator_test/hsm_test.py
index 7d23a913c3d..26ee1de3331 100644
--- a/caffe2/python/operator_test/hsm_test.py
+++ b/caffe2/python/operator_test/hsm_test.py
@@ -119,7 +119,8 @@ class TestHsm(hu.HypothesisTestCase):
         for i in range(names.shape[0]):
             for j in range(names.shape[1]):
                 if names[i][j]:
-                    assert(names[i][j] == p_names[i][j])
+                    self.assertEqual(
+                        names[i][j], p_names[i][j].item().encode('utf-8'))
                 self.assertAlmostEqual(
                     scores[i][j], p_scores[i][j], delta=0.001)
diff --git a/caffe2/python/operator_test/image_input_op_test.py b/caffe2/python/operator_test/image_input_op_test.py
index 4eaa4c43254..8576e986d35 100644
--- a/caffe2/python/operator_test/image_input_op_test.py
+++ b/caffe2/python/operator_test/image_input_op_test.py
@@ -7,13 +7,17 @@ import unittest
 try:
     import cv2
 except ImportError:
-    raise unittest.SkipTest('python-opencv is not installed')
+    pass  # Handled below
 from PIL import Image
 import numpy as np
 import lmdb
 import shutil
-import StringIO
+try:
+    import StringIO
+except ImportError:
+    from io import StringIO
+import sys
 import tempfile

 # TODO: This test does not test scaling because
@@ -194,6 +198,7 @@ def create_test(output_dir, width, height, default_bound,
     return expected_results


+@unittest.skipIf('cv2' not in sys.modules, 'python-opencv is not installed')
 class TestImport(hu.HypothesisTestCase):
     @given(size_tuple=st.tuples(
         st.integers(min_value=8, max_value=4096),
diff --git a/caffe2/python/operator_test/index_ops_test.py b/caffe2/python/operator_test/index_ops_test.py
index b121ff9676e..642f340fad8 100644
--- a/caffe2/python/operator_test/index_ops_test.py
+++ b/caffe2/python/operator_test/index_ops_test.py
@@ -64,8 +64,14 @@ class TestIndexOps(TestCase):
             ['stored_entries']))
         stored_actual = workspace.FetchBlob('stored_entries')
         new_entries = np.array([entries[3], entries[4]], dtype=dtype)
-        np.testing.assert_array_equal(
-            np.concatenate((my_entries, new_entries)), stored_actual)
+        expected = np.concatenate((my_entries, new_entries))
+        if dtype is str:
+            # we'll always get bytes back from Caffe2
+            expected = np.array([
+                x.item().encode('utf-8') if isinstance(x, np.str_) else x
+                for x in expected
+            ], dtype=object)
+        np.testing.assert_array_equal(expected, stored_actual)

         workspace.RunOperatorOnce(core.CreateOperator(
             index_create_op,
diff --git a/caffe2/python/operator_test/rnn_cell_test.py b/caffe2/python/operator_test/rnn_cell_test.py
index 71d66c488b6..8864c5b6c3f 100644
--- a/caffe2/python/operator_test/rnn_cell_test.py
+++ b/caffe2/python/operator_test/rnn_cell_test.py
@@ -458,10 +458,9 @@ def prepare_mul_rnn(model, input_blob, shape, T, outputs_with_grad, num_layers):
         model=model,
         inputs=input_blob,
         initial_states=states,
-        outputs_with_grads=map(
-            lambda x: x + 2 * (num_layers - 1),
-            outputs_with_grad
-        ),
+        outputs_with_grads=[
+            x + 2 * (num_layers - 1) for x in outputs_with_grad
+        ],
         seq_lengths=None,
     )
     return results[-2:]
@@ -682,12 +681,12 @@ class RNNCellTest(hu.HypothesisTestCase):
         for arg in op.arg:
             if arg.name == "step_net":
                 step_proto = caffe2_pb2.NetDef()
-                protobuftx.Merge(arg.s, step_proto)
+                protobuftx.Merge(arg.s.decode("ascii"), step_proto)
                 for step_op in step_proto.op:
                     self.assertEqual(0, step_op.device_option.device_type)
                     self.assertEqual(1, step_op.device_option.cuda_gpu_id)
             elif arg.name == 'backward_step_net':
-                self.assertEqual("", arg.s)
+                self.assertEqual(b"", arg.s)


 @given(encoder_output_length=st.integers(1, 3),
diff --git a/caffe2/python/operator_test/stats_ops_test.py b/caffe2/python/operator_test/stats_ops_test.py
index 595161f76f1..edc36facb23 100644
--- a/caffe2/python/operator_test/stats_ops_test.py
+++ b/caffe2/python/operator_test/stats_ops_test.py
@@ -19,7 +19,10 @@ class TestCounterOps(TestCase):
         existing = len(previous_keys)

         prefix = '/'.join([__name__, 'TestCounterOps', 'test_stats_ops'])
-        keys = [prefix + '/key1', prefix + '/key2']
+        keys = [
+            (prefix + '/key1').encode('ascii'),
+            (prefix + '/key2').encode('ascii')
+        ]
         values = [34, 45]
         workspace.FeedBlob('k', np.array(keys, dtype=str))
         workspace.FeedBlob('v', np.array(values, dtype=np.int64))
diff --git a/caffe2/python/operator_test/text_file_reader_test.py b/caffe2/python/operator_test/text_file_reader_test.py
index 300a88dce35..41ba814af6a 100644
--- a/caffe2/python/operator_test/text_file_reader_test.py
+++ b/caffe2/python/operator_test/text_file_reader_test.py
@@ -8,7 +8,6 @@ from caffe2.python.test_util import TestCase
 from caffe2.python.schema import Struct, Scalar, FetchRecord
 import tempfile
 import numpy as np
-import os


 class TestTextFileReader(TestCase):
@@ -24,46 +23,44 @@ class TestTextFileReader(TestCase):
             [0.456, 0.789, 0.10101, -24342.64],
         ]
         row_data = list(zip(*col_data))
-        txt_file = tempfile.NamedTemporaryFile(delete=False)
-        txt_file.write(
-            '\n'.join(
-                '\t'.join(str(x) for x in f)
-                for f in row_data
-            ) + '\n'
-        )
-        txt_file.close()
+        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as txt_file:
+            txt_file.write(
+                '\n'.join(
+                    '\t'.join(str(x) for x in f)
+                    for f in row_data
+                ) + '\n'
+            )
+            txt_file.flush()

-        for num_passes in range(1, 3):
-            for batch_size in range(1, len(row_data) + 2):
-                init_net = core.Net('init_net')
-                reader = TextFileReader(
-                    init_net,
-                    filename=txt_file.name,
-                    schema=schema,
-                    batch_size=batch_size,
-                    num_passes=num_passes)
-                workspace.RunNetOnce(init_net)
+            for num_passes in range(1, 3):
+                for batch_size in range(1, len(row_data) + 2):
+                    init_net = core.Net('init_net')
+                    reader = TextFileReader(
+                        init_net,
+                        filename=txt_file.name,
+                        schema=schema,
+                        batch_size=batch_size,
+                        num_passes=num_passes)
+                    workspace.RunNetOnce(init_net)

-                net = core.Net('read_net')
-                should_stop, record = reader.read_record(net)
+                    net = core.Net('read_net')
+                    should_stop, record = reader.read_record(net)

-                results = [np.array([])] * num_fields
-                while True:
-                    workspace.RunNetOnce(net)
-                    arrays = FetchRecord(record).field_blobs()
+                    results = [np.array([])] * num_fields
+                    while True:
+                        workspace.RunNetOnce(net)
+                        arrays = FetchRecord(record).field_blobs()
+                        for i in range(num_fields):
+                            results[i] = np.append(results[i], arrays[i])
+                        if workspace.FetchBlob(should_stop):
+                            break
                     for i in range(num_fields):
-                        results[i] = np.append(results[i], arrays[i])
-                    if workspace.FetchBlob(should_stop):
-                        break
-                for i in range(num_fields):
-                    col_batch = np.tile(col_data[i], num_passes)
-                    if col_batch.dtype in (np.float32, np.float64):
-                        np.testing.assert_array_almost_equal(
-                            col_batch, results[i], decimal=3)
-                    else:
-                        np.testing.assert_array_equal(col_batch, results[i])
-
-        os.remove(txt_file.name)
+                        col_batch = np.tile(col_data[i], num_passes)
+                        if col_batch.dtype in (np.float32, np.float64):
+                            np.testing.assert_array_almost_equal(
+                                col_batch, results[i], decimal=3)
+                        else:
+                            np.testing.assert_array_equal(col_batch, results[i])


 if __name__ == "__main__":
     import unittest
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index 747ccfbcc3f..164feb20518 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -231,7 +231,7 @@ PythonOpBase::PythonOpBase(
   CAFFE_ENFORCE(pickle);
   auto loads = pickle.attr("loads").cast<py::object>();
   CAFFE_ENFORCE(loads);
-  auto builder_call = loads(pickled).cast<py::tuple>();
+  auto builder_call = loads(py::bytes(pickled)).cast<py::tuple>();
   CAFFE_ENFORCE(builder_call);
   CAFFE_ENFORCE_EQ(py::len(builder_call), 3);
   auto func = builder_call[0].cast<py::object>();
@@ -504,7 +504,8 @@ void addObjectMethods(py::module& m) {
          "_create_net",
          [](Workspace* self, py::bytes def, bool overwrite) -> py::object {
            caffe2::NetDef proto;
-            CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
+            CAFFE_ENFORCE(
+                ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
            auto* net = self->CreateNet(proto, overwrite);
            CAFFE_ENFORCE(net);
            return py::cast(net, py::return_value_policy::reference_internal);
@@ -527,7 +528,8 @@ void addObjectMethods(py::module& m) {
          "_run_net",
          [](Workspace* self, py::bytes def) {
            caffe2::NetDef proto;
-            CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
+            CAFFE_ENFORCE(
+                ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
            py::gil_scoped_release g;
            CAFFE_ENFORCE(self->RunNetOnce(proto));
          })
@@ -535,7 +537,8 @@ void addObjectMethods(py::module& m) {
          "_run_operator",
          [](Workspace* self, py::bytes def) {
            caffe2::OperatorDef proto;
-            CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
+            CAFFE_ENFORCE(
+                ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
            py::gil_scoped_release g;
            CAFFE_ENFORCE(self->RunOperatorOnce(proto));
          })
@@ -543,7 +546,8 @@ void addObjectMethods(py::module& m) {
          "_run_plan",
          [](Workspace* self, py::bytes def) {
            caffe2::PlanDef proto;
-            CAFFE_ENFORCE(ParseProtobufFromLargeString(def, &proto));
+            CAFFE_ENFORCE(
+                ParseProtobufFromLargeString(def.cast<std::string>(), &proto));
            py::gil_scoped_release g;
            CAFFE_ENFORCE(self->RunPlan(proto));
          })
@@ -568,7 +572,8 @@ void addObjectMethods(py::module& m) {
      "get_gradient_defs",
      [](py::bytes op_def, std::vector<GradientWrapper> output_gradients) {
        OperatorDef def;
-        CAFFE_ENFORCE(ParseProtobufFromLargeString(op_def, &def));
+        CAFFE_ENFORCE(
+            ParseProtobufFromLargeString(op_def.cast<std::string>(), &def));
        CAFFE_ENFORCE(caffe2::GradientRegistry()->Has(def.type()));
        const auto& meta = GetGradientForOp(def, output_gradients);
        std::vector<py::bytes> grad_ops;
@@ -639,9 +644,10 @@ void addObjectMethods(py::module& m) {
          [](Predictor& instance, py::bytes init_net, py::bytes predict_net) {
            CAFFE_ENFORCE(gWorkspace);
            NetDef init_net_, predict_net_;
-            CAFFE_ENFORCE(ParseProtobufFromLargeString(init_net, &init_net_));
-            CAFFE_ENFORCE(
-                ParseProtobufFromLargeString(predict_net, &predict_net_));
+            CAFFE_ENFORCE(ParseProtobufFromLargeString(
+                init_net.cast<std::string>(), &init_net_));
+            CAFFE_ENFORCE(ParseProtobufFromLargeString(
+                predict_net.cast<std::string>(), &predict_net_));
            new (&instance) Predictor(init_net_, predict_net_, gWorkspace);
          })
      .def(
@@ -781,15 +787,16 @@ void addGlobalMethods(py::module& m) {
   m.def(
       "create_net",
       [](py::bytes net_def, bool overwrite) {
+        CAFFE_ENFORCE(gWorkspace);
         caffe2::NetDef proto;
         CAFFE_ENFORCE(
-            ParseProtobufFromLargeString(net_def, &proto),
+            ParseProtobufFromLargeString(net_def.cast<std::string>(), &proto),
             "Can't parse net proto: ",
-            std::string(net_def));
+            net_def.cast<std::string>());
         CAFFE_ENFORCE(
             gWorkspace->CreateNet(proto, overwrite),
             "Error creating net with proto: ",
-            std::string(net_def));
+            net_def.cast<std::string>());
         return true;
       },
       py::arg("net_def"),
@@ -834,7 +841,8 @@ void addGlobalMethods(py::module& m) {
   m.def("run_operator_once", [](const py::bytes& op_def) {
     CAFFE_ENFORCE(gWorkspace);
     OperatorDef def;
-    CAFFE_ENFORCE(ParseProtobufFromLargeString(op_def, &def));
+    CAFFE_ENFORCE(
+        ParseProtobufFromLargeString(op_def.cast<std::string>(), &def));
     py::gil_scoped_release g;
     CAFFE_ENFORCE(gWorkspace->RunOperatorOnce(def));
     return true;
   });
   m.def("run_net_once", [](const py::bytes& net_def) {
     CAFFE_ENFORCE(gWorkspace);
     NetDef def;
-    CAFFE_ENFORCE(ParseProtobufFromLargeString(net_def, &def));
+    CAFFE_ENFORCE(
+        ParseProtobufFromLargeString(net_def.cast<std::string>(), &def));
     py::gil_scoped_release g;
     CAFFE_ENFORCE(gWorkspace->RunNetOnce(def));
     return true;
   });
   m.def("run_plan", [](const py::bytes& plan_def) {
     CAFFE_ENFORCE(gWorkspace);
-    const std::string& msg = std::move(plan_def);
     PlanDef def;
-    CAFFE_ENFORCE(ParseProtobufFromLargeString(msg, &def));
+    CAFFE_ENFORCE(
+        ParseProtobufFromLargeString(plan_def.cast<std::string>(), &def));
     py::gil_scoped_release g;
     CAFFE_ENFORCE(gWorkspace->RunPlan(def));
     return true;
diff --git a/caffe2/python/utils.py b/caffe2/python/utils.py
index 64568b1655a..a3c348da394 100644
--- a/caffe2/python/utils.py
+++ b/caffe2/python/utils.py
@@ -11,7 +11,7 @@ import sys
 import collections
 import functools
 import numpy as np
-from six import integer_types, string_types, text_type
+from six import integer_types, binary_type, text_type


 def CaffeBlobToNumpyArray(blob):
@@ -80,9 +80,10 @@ def MakeArgument(key, value):
         # We make a relaxation that a boolean variable will also be stored as
         # int.
         argument.i = value
-    elif isinstance(value, string_types):
-        argument.s = (value.encode('utf-8') if isinstance(value, text_type)
-                      else value)
+    elif isinstance(value, binary_type):
+        argument.s = value
+    elif isinstance(value, text_type):
+        argument.s = value.encode('utf-8')
     elif isinstance(value, Message):
         argument.s = value.SerializeToString()
     elif iterable and all(type(v) in [float, np.float_] for v in value):
@@ -95,7 +96,9 @@ def MakeArgument(key, value):
         argument.ints.extend(
             v.item() if type(v) is np.int_ else v for v in value
         )
-    elif iterable and all(isinstance(v, string_types) for v in value):
+    elif iterable and all(
+        isinstance(v, binary_type) or isinstance(v, text_type) for v in value
+    ):
         argument.strings.extend(
             v.encode('utf-8') if isinstance(v, text_type) else v
             for v in value
@@ -103,10 +106,19 @@ def MakeArgument(key, value):
     elif iterable and all(isinstance(v, Message) for v in value):
         argument.strings.extend(v.SerializeToString() for v in value)
     else:
-        raise ValueError(
-            "Unknown argument type: key=%s value=%s, value type=%s" %
-            (key, str(value), str(type(value)))
-        )
+        if iterable:
+            raise ValueError(
+                "Unknown iterable argument type: key={} value={}, value "
+                "type={}[{}]".format(
+                    key, value, type(value), set(type(v) for v in value)
+                )
+            )
+        else:
+            raise ValueError(
+                "Unknown argument type: key={} value={}, value type={}".format(
+                    key, value, type(value)
+                )
+            )
     return argument
diff --git a/caffe2/python/workspace_test.py b/caffe2/python/workspace_test.py
index 009be2d7fce..e9b8a10a7f9 100644
--- a/caffe2/python/workspace_test.py
+++ b/caffe2/python/workspace_test.py
@@ -385,7 +385,6 @@ class TestCWorkspace(htu.HypothesisTestCase):

     @given(name=st.text(), value=st.floats(min_value=-1, max_value=1.0))
     def test_operator_run(self, name, value):
-        name = name.encode('ascii', 'ignore')
         ws = workspace.C.Workspace()
         op = core.CreateOperator(
             "ConstantFill", [], [name], shape=[1], value=value)
@@ -398,7 +397,6 @@ class TestCWorkspace(htu.HypothesisTestCase):
            net_name=st.text(),
            value=st.floats(min_value=-1, max_value=1.0))
     def test_net_run(self, blob_name, net_name, value):
-        blob_name = blob_name.encode('ascii', 'ignore')
         ws = workspace.C.Workspace()
         net = core.Net(net_name)
         net.ConstantFill([], [blob_name], shape=[1], value=value)
@@ -413,7 +411,6 @@ class TestCWorkspace(htu.HypothesisTestCase):
            plan_name=st.text(),
            value=st.floats(min_value=-1, max_value=1.0))
     def test_plan_run(self, blob_name, plan_name, net_name, value):
-        blob_name = blob_name.encode('ascii', 'ignore')
         ws = workspace.C.Workspace()
         plan = core.Plan(plan_name)
         net = core.Net(net_name)
@@ -431,7 +428,6 @@ class TestCWorkspace(htu.HypothesisTestCase):
            net_name=st.text(),
            value=st.floats(min_value=-1, max_value=1.0))
     def test_net_create(self, blob_name, net_name, value):
-        blob_name = blob_name.encode('ascii', 'ignore')
         ws = workspace.C.Workspace()
         net = core.Net(net_name)
         net.ConstantFill([], [blob_name], shape=[1], value=value)
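
---

Note for reviewers (appended after the diff, not part of the patch): nearly
every hunk above applies one pattern — normalize str/bytes with six's
text_type/binary_type at the Python/C++ boundary, as in BlobReference.__init__
and MakeArgument. A condensed sketch of that pattern, with hypothetical helper
names (to_text, to_bytes do not appear in the patch):

    from six import binary_type, text_type

    def to_text(value):
        # Blob names and serialized protos coming back from C++ may be bytes;
        # decode once at the boundary, as BlobReference.__init__ does.
        if isinstance(value, binary_type):
            return value.decode('utf-8')
        return text_type(value)

    def to_bytes(value):
        # Values headed into protobuf string fields must be bytes,
        # as MakeArgument does for argument.s.
        if isinstance(value, text_type):
            return value.encode('utf-8')
        return value

Accepting either type on input and storing a single canonical representation
keeps the public API unchanged while letting the same code run under both
Python 2 and Python 3.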