mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Extend support to arbitrary ops in init net when converting c2 models to onnx (#7256)
This commit is contained in:
parent
8091388d0f
commit
a95b7b13f9
|
|
@ -270,36 +270,12 @@ class Caffe2Frontend(object):
|
|||
|
||||
@classmethod
|
||||
def caffe2_init_net_to_initializer(cls, init_net):
|
||||
initializer = []
|
||||
ws, _ = c2_native_run_net(init_net=None, predict_net=init_net, inputs=[])
|
||||
output_names = []
|
||||
for op in init_net.op:
|
||||
assert not op.input
|
||||
try:
|
||||
data_type, field_name = {
|
||||
'GivenTensorFill': (TensorProto.FLOAT, 'floats'),
|
||||
'GivenTensorInt64Fill': (TensorProto.INT64, 'ints'),
|
||||
'GivenTensorIntFill': (TensorProto.INT32, 'ints'),
|
||||
'GivenTensorBoolFill': (TensorProto.BOOL, 'ints'),
|
||||
'GivenTensorStringFill': (TensorProto.STRING, 'strings'),
|
||||
}[op.type]
|
||||
except KeyError:
|
||||
raise RuntimeError(
|
||||
"Can not translate init_net with operator '{}' "
|
||||
"to initializer".format(op.type)
|
||||
)
|
||||
raw = (data_type != TensorProto.STRING)
|
||||
args = {a.name: a for a in op.arg}
|
||||
vals = getattr(args['values'], field_name)
|
||||
if raw:
|
||||
vals = np.asarray(
|
||||
vals,
|
||||
dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[data_type]).tobytes()
|
||||
initializer.append(make_tensor(
|
||||
name=op.output[0],
|
||||
data_type=data_type,
|
||||
dims=args['shape'].ints,
|
||||
vals=vals,
|
||||
raw=raw,
|
||||
))
|
||||
output_names.extend(op.output)
|
||||
initializer = [numpy_helper.from_array(ws.FetchBlob(name), name=name)
|
||||
for name in sorted(set(output_names))]
|
||||
return initializer
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -390,7 +390,6 @@ class TestCaffe2End2End(TestCase):
|
|||
def test_inception_v2(self):
|
||||
self._test_net('inception_v2')
|
||||
|
||||
@unittest.skip('Need to add support for ConstantFill operator')
|
||||
def test_squeezenet(self):
|
||||
self._test_net('squeezenet')
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user