mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
* Move onnx-caffe2 inside caffe2 * Update to the latest onnx-caffe2 and update jenkins env * Rename onnx_caffe2 to onnx * Add __init__.py to caffe2/python/onnx * Change CI check variable to JENKINS_URL * Cherrypick recent onnx-caffe2 update
78 lines
3.5 KiB
Python
78 lines
3.5 KiB
Python
# Copyright (c) 2016-present, Facebook, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
##############################################################################
|
|
|
|
## @package onnx
|
|
# Module caffe2.python.onnx.backend_rep
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import core, workspace
|
|
from caffe2.proto import caffe2_pb2
|
|
from onnx.backend.base import BackendRep, namedtupledict
|
|
|
|
class Caffe2Rep(BackendRep):
    """ONNX ``BackendRep`` that executes a converted model with Caffe2.

    Holds a Caffe2 init net and predict net together with the workspace
    they run in. The first call to ``run`` creates both nets and executes
    the init net once; every call to ``run`` feeds the given inputs, runs
    the predict net and returns the predict net's external outputs.
    """

    def __init__(self, init_net, predict_net, workspace, uninitialized):
        """Store the nets and workspace used by ``run``.

        Args:
            init_net: Caffe2 NetDef that is created and run once, on the
                first call to ``run``, before the predict net.
            predict_net: Caffe2 NetDef run on every call to ``run``. Its
                ``device_option`` selects the device scope, and its
                ``external_output`` names the blobs that are returned.
            workspace: context manager entered around every ``run``
                (presumably it activates this model's Caffe2 workspace —
                TODO confirm against the backend that constructs this).
            uninitialized: ordered names of the external inputs that are
                not initialized in the workspace. Positional inputs given
                to ``run`` are paired with these names in order.
        """
        super(Caffe2Rep, self).__init__()
        self.init_net = init_net
        self.predict_net = predict_net
        self.workspace = workspace
        # The list of uninitialized external_inputs in workspace; we need it
        # to pair names with positional (sequence or single) inputs.
        self.uninitialized = uninitialized
        # One-shot flags: nets are created, and the init net is run, only
        # once per representation rather than on every ``run``.
        self.nets_created = False
        self.ran_init_net = False

    @property
    def _name_scope(self):
        """Return ``'gpu_<id>'`` when the predict net targets CUDA, else ``''``."""
        device_option = self.predict_net.device_option
        if device_option.device_type == caffe2_pb2.CUDA:
            return 'gpu_{}'.format(device_option.cuda_gpu_id)
        return ''

    def _feed_inputs(self, inputs):
        """Feed ``inputs`` into the currently active Caffe2 workspace.

        Args:
            inputs: one of
                - a dict mapping blob name to value; fed inside
                  ``core.NameScope(self._name_scope)``;
                - a list or tuple of values, paired in order with
                  ``self.uninitialized``;
                - any other object, treated as the value of the first
                  uninitialized input.

        Raises:
            RuntimeError: if a list/tuple's length differs from the number
                of uninitialized graph inputs, or if a single value is
                given but the graph has no uninitialized inputs.
        """
        if isinstance(inputs, dict):
            # NOTE(review): the scope presumably prefixes names with the
            # device scope (e.g. gpu_0/) — confirm FeedBlob applies it.
            with core.NameScope(self._name_scope):
                for key, value in inputs.items():
                    workspace.FeedBlob(key, value)
        elif isinstance(inputs, (list, tuple)):
            if len(self.uninitialized) != len(inputs):
                raise RuntimeError('Expected {} values for uninitialized '
                                   'graph inputs ({}), but got {}.'.format(
                                       len(self.uninitialized),
                                       ', '.join(self.uninitialized),
                                       len(inputs)))
            # namescope already baked into protobuf
            for name, value in zip(self.uninitialized, inputs):
                workspace.FeedBlob(name, value)
        else:
            # single input
            if not self.uninitialized:
                # Previously this surfaced as a bare IndexError.
                raise RuntimeError('Got a single input value, but the graph '
                                   'has no uninitialized inputs to feed it '
                                   'to.')
            workspace.FeedBlob(self.uninitialized[0], inputs)

    def run(self, inputs, **kwargs):
        """Feed ``inputs``, run the predict net and return its outputs.

        Args:
            inputs: a dict of blob name to value, a list/tuple of values
                matching ``self.uninitialized`` in order, or a single value
                for the first uninitialized input.
            **kwargs: forwarded to ``BackendRep.run``.

        Returns:
            A ``namedtupledict`` named ``Outputs`` whose fields are the
            predict net's ``external_output`` names, in order, holding the
            fetched blob values.

        Raises:
            RuntimeError: if ``inputs`` does not match the uninitialized
                graph inputs (see ``_feed_inputs``).
        """
        super(Caffe2Rep, self).run(inputs, **kwargs)
        with self.workspace:
            with core.DeviceScope(self.predict_net.device_option):
                self._feed_inputs(inputs)
                if not self.nets_created:
                    workspace.CreateNet(self.init_net)
                    workspace.CreateNet(self.predict_net)
                    self.nets_created = True
                if not self.ran_init_net:
                    workspace.RunNet(self.init_net.name)
                    self.ran_init_net = True
                workspace.RunNet(self.predict_net.name)
            # Fetch while still inside ``self.workspace`` so the blobs come
            # from this model's workspace.
            output_names = self.predict_net.external_output
            output_values = [workspace.FetchBlob(name)
                             for name in output_names]
            return namedtupledict('Outputs', output_names)(*output_values)
|