mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17158 Because of Reshape op, batch size can be changed. This diff addresses first order issue raised from multiple batch size system. We need to export different real_batch_size for different max_batch_size input and attach it to the right output. It also fixes a false exception. Reviewed By: ipiszy Differential Revision: D14099541 fbshipit-source-id: 0fa9e86826f417a11d2b5dd2ee60dff64a7ce8c4
43 lines
1.2 KiB
Python
43 lines
1.2 KiB
Python
## @package onnx
|
|
#Module caffe2.python.onnx.onnxifi
|
|
|
|
"""
|
|
ONNXIFI a Caffe2 net: collapse its ONNXIFI-runnable nodes into Onnxifi Caffe2 ops.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.proto import caffe2_pb2
|
|
from caffe2.python import core, workspace
|
|
import caffe2.python._import_c_extension as C
|
|
import numpy as np
|
|
|
|
|
|
def onnxifi_caffe2_net(
        pred_net,
        input_shapes,
        max_batch_size=1,
        max_seq_size=1,
        debug=False,
        use_onnx=True,
        black_list=None):
    """
    Transform the caffe2_net by collapsing ONNXIFI-runnable nodes into Onnxifi c2 ops

    The net is serialized, handed to the ``C.onnxifi`` extension together with
    the shape hints and limits below, and the serialized net it returns is
    parsed back into a fresh ``caffe2_pb2.NetDef``.

    Args:
        pred_net: the Caffe2 net to transform (anything with
            ``SerializeToString()``, i.e. a protobuf NetDef). Only its
            serialized bytes are passed on; this function does not mutate it.
        input_shapes: mapping from input blob name to a shape hint. It is
            copied into a plain dict before the call. The value format is
            whatever ``C.onnxifi`` expects (presumably a list of ints --
            TODO confirm against the binding).
        max_batch_size: forwarded to ``C.onnxifi``. Presumably the largest
            batch size the generated Onnxifi ops must handle -- confirm.
        max_seq_size: forwarded to ``C.onnxifi``. Presumably the largest
            sequence length to plan for -- confirm.
        debug: forwarded to ``C.onnxifi`` as its debug flag.
        use_onnx: forwarded to ``C.onnxifi``. Presumably selects ONNX (vs. a
            Caffe2 proto) as the format of the offloaded subgraphs -- confirm.
        black_list: optional sequence forwarded to ``C.onnxifi``; ``None`` or
            any empty value is sent as ``[]``. Presumably identifies ops that
            must stay in Caffe2 -- confirm.

    Returns:
        A new ``caffe2_pb2.NetDef`` parsed from the bytes returned by
        ``C.onnxifi``.
    """
    # Hand the binding a builtin dict whatever Mapping type the caller used.
    shape_hints = dict(input_shapes.items())

    # Argument order must match the C.onnxifi binding's positional signature.
    pred_net_str = C.onnxifi(
        pred_net.SerializeToString(),
        shape_hints,
        black_list or [],
        max_batch_size,
        max_seq_size,
        debug,
        use_onnx)

    pred_net_cut = caffe2_pb2.NetDef()
    pred_net_cut.ParseFromString(pred_net_str)
    return pred_net_cut
|