pytorch/caffe2/python/onnx/onnxifi.py
Yinghai Lu c3a0000864 Support communicating with C2 protobuf in Onnxifi flow (#15472)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15472

Create a path to pass serialized C2 protobuf instead of ONNX during ONNXIFI flow

Reviewed By: houseroad

Differential Revision: D13536603

fbshipit-source-id: 7d016474f4beedbda480ed2e2c0004af7868aafe
2019-01-07 22:12:29 -08:00

54 lines
1.6 KiB
Python

## @package onnx
#Module caffe2.python.onnx.onnxifi
"""
ONNXIFI a Caffe2 net
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace
import caffe2.python._import_c_extension as C
import numpy as np
def onnxifi_caffe2_net(
        pred_net,
        input_shapes,
        infer_shapes=False,
        debug=False,
        use_onnx=True):
    """
    Transform the caffe2_net by collapsing ONNXIFI-runnable nodes into Onnxifi c2 ops

    Args:
        pred_net: net to transform; it is serialized with
            ``SerializeToString()`` before being handed to the C extension
            (presumably a ``caffe2_pb2.NetDef`` -- confirm against callers).
        input_shapes: mapping of input blob name -> shape. Each shape is
            unpacked into ``np.random.randn`` and is also forwarded as a
            shape hint.
        infer_shapes: forwarded to ``C.onnxifi``. When falsy, a random
            float32 placeholder tensor is fed into the current workspace for
            every input that has no blob of the requested shape yet.
        debug: forwarded unchanged to ``C.onnxifi``.
        use_onnx: forwarded unchanged to ``C.onnxifi``.

    Returns:
        A new ``caffe2_pb2.NetDef`` parsed from the bytes returned by
        ``C.onnxifi``.

    Side effects:
        With ``infer_shapes`` falsy, blobs named in ``input_shapes`` may be
        created or overwritten in the workspace with random data.
    """
    # Without shape inference, inject a fake input tensor for each declared
    # input so that its shape can be populated from the workspace.
    fed_inputs = []
    if not infer_shapes:
        for blob_name, dims in input_shapes.items():
            # An existing blob that already has the requested shape is left
            # untouched (and is not reported as an injected input).
            # NOTE(review): this is a plain ``==`` against the fetched blob's
            # ``.shape``; a list-valued ``dims`` presumably never equals a
            # shape tuple, so such blobs would always be re-fed -- confirm
            # that callers pass tuples.
            if (workspace.HasBlob(blob_name)
                    and workspace.FetchBlob(blob_name).shape == dims):
                continue
            placeholder = np.random.randn(*dims).astype(np.float32)
            workspace.FeedBlob(blob_name, placeholder)
            fed_inputs.append(blob_name)

    # Every declared input shape is passed along as a hint, regardless of
    # whether shape inference is enabled.
    hints = {blob_name: dims for blob_name, dims in input_shapes.items()}

    serialized_result = C.onnxifi(
        pred_net.SerializeToString(),
        fed_inputs,
        hints,
        infer_shapes,
        debug,
        use_onnx,
    )

    # The extension returns a serialized net; rebuild the protobuf message.
    transformed_net = caffe2_pb2.NetDef()
    transformed_net.ParseFromString(serialized_result)
    return transformed_net