Extend support to arbitrary ops in init net when converting c2 models to onnx (#7256)

This commit is contained in:
bddppq 2018-05-03 15:34:47 -07:00 committed by GitHub
parent 8091388d0f
commit a95b7b13f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 30 deletions

View File

@ -270,36 +270,12 @@ class Caffe2Frontend(object):
@classmethod
def caffe2_init_net_to_initializer(cls, init_net):
initializer = []
ws, _ = c2_native_run_net(init_net=None, predict_net=init_net, inputs=[])
output_names = []
for op in init_net.op:
assert not op.input
try:
data_type, field_name = {
'GivenTensorFill': (TensorProto.FLOAT, 'floats'),
'GivenTensorInt64Fill': (TensorProto.INT64, 'ints'),
'GivenTensorIntFill': (TensorProto.INT32, 'ints'),
'GivenTensorBoolFill': (TensorProto.BOOL, 'ints'),
'GivenTensorStringFill': (TensorProto.STRING, 'strings'),
}[op.type]
except KeyError:
raise RuntimeError(
"Can not translate init_net with operator '{}' "
"to initializer".format(op.type)
)
raw = (data_type != TensorProto.STRING)
args = {a.name: a for a in op.arg}
vals = getattr(args['values'], field_name)
if raw:
vals = np.asarray(
vals,
dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[data_type]).tobytes()
initializer.append(make_tensor(
name=op.output[0],
data_type=data_type,
dims=args['shape'].ints,
vals=vals,
raw=raw,
))
output_names.extend(op.output)
initializer = [numpy_helper.from_array(ws.FetchBlob(name), name=name)
for name in sorted(set(output_names))]
return initializer
@classmethod

View File

@ -390,7 +390,6 @@ class TestCaffe2End2End(TestCase):
def test_inception_v2(self):
self._test_net('inception_v2')
@unittest.skip('Need to add support for ConstantFill operator')
def test_squeezenet(self):
self._test_net('squeezenet')