mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15472 Create a path to pass serialized C2 protobuf instead of ONNX during ONNXIFI flow Reviewed By: houseroad Differential Revision: D13536603 fbshipit-source-id: 7d016474f4beedbda480ed2e2c0004af7868aafe
54 lines
1.6 KiB
Python
54 lines
1.6 KiB
Python
## @package onnx
|
|
#Module caffe2.python.onnx.onnxifi
|
|
|
|
"""
|
|
ONNXIFI a Caffe2 net
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.proto import caffe2_pb2
|
|
from caffe2.python import core, workspace
|
|
import caffe2.python._import_c_extension as C
|
|
import numpy as np
|
|
|
|
|
|
def onnxifi_caffe2_net(
        pred_net,
        input_shapes,
        infer_shapes=False,
        debug=False,
        use_onnx=True):
    """
    Transform the caffe2_net by collapsing ONNXIFI-runnable nodes into Onnxifi c2 ops.

    Args:
        pred_net: the Caffe2 ``NetDef`` to transform. It is serialized and
            handed to ``C.onnxifi``; this function does not modify it.
        input_shapes: mapping from input blob name to its shape (a sequence
            of ints, e.g. a list or a tuple). Every entry is forwarded to
            ``C.onnxifi`` unchanged as a shape hint.
        infer_shapes: if False, each input whose blob is missing from the
            current workspace, or whose fetched value has a different shape,
            gets a random float32 tensor of the hinted shape fed into the
            workspace, and its name is passed to ``C.onnxifi`` as an external
            input. If True, the workspace is left untouched and no external
            inputs are passed. The flag itself is also forwarded to
            ``C.onnxifi``.
        debug: forwarded unchanged to ``C.onnxifi``.
        use_onnx: forwarded unchanged to ``C.onnxifi``. Presumably False
            selects the path that passes serialized Caffe2 protobuf instead
            of ONNX to the backend -- TODO confirm against the C++ transformer.

    Returns:
        A new ``caffe2_pb2.NetDef`` parsed from the serialized net returned
        by ``C.onnxifi``.
    """
    external_inputs = []
    if not infer_shapes:
        # Without shape inference, inject a fake input tensor for each input
        # so the transform can populate shapes from the workspace.
        for name, shape in input_shapes.items():
            wanted = tuple(shape)
            # Normalize both sides to tuples: an ndarray's .shape is a tuple,
            # which never compares equal to a list, so a list-valued hint
            # would otherwise always force a re-feed that overwrites the blob.
            if workspace.HasBlob(name) and \
                    tuple(workspace.FetchBlob(name).shape) == wanted:
                continue
            # np.random.randn() with no arguments returns a plain Python
            # float (no .astype); np.asarray also handles rank-0 shapes.
            workspace.FeedBlob(
                name, np.asarray(np.random.randn(*wanted), dtype=np.float32))
            # NOTE(review): only inputs that were (re)fed are reported as
            # external inputs; confirm this is what the C++ transform expects.
            external_inputs.append(name)

    shape_hints = {name: shape for name, shape in input_shapes.items()}
    pred_net_str = C.onnxifi(pred_net.SerializeToString(),
                             external_inputs,
                             shape_hints,
                             infer_shapes,
                             debug,
                             use_onnx)
    pred_net_cut = caffe2_pb2.NetDef()
    pred_net_cut.ParseFromString(pred_net_str)
    return pred_net_cut
|