pytorch/caffe2/python/onnx/backend_rep.py
Yang Liu 7d7d336c45 Back out "codemod cuda_gpu_id to device_id"
Summary:
Original commit changeset: f5614a5d2607

D9986213 is causing a [huge performance difference](https://our.intern.facebook.com/intern/ads/analyze_canary/412951953278781781/) in Multifeed Aggregator and has been blocking the aggregator push since last Friday night: https://fburl.com/feedtools/b6izvwjz
We need to land this revert ASAP to unblock the aggregator push.

Reviewed By: orionr

Differential Revision: D10123245

fbshipit-source-id: d83da8e00a1250f5d09811a0a587c127e377aab2
2018-10-01 11:31:14 -07:00

62 lines
2.8 KiB
Python

## @package onnx
# Module caffe2.python.onnx.backend_rep
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core
from caffe2.proto import caffe2_pb2
from onnx.backend.base import BackendRep, namedtupledict
class Caffe2Rep(BackendRep):
    """ONNX ``BackendRep`` that executes a prepared Caffe2 model.

    Holds an initialization net and a prediction net that run inside
    ``workspace``. Both nets are created lazily on the first call to
    :meth:`run`, and the initialization net is executed only once. Every
    call to :meth:`run` feeds the inputs, then executes the prediction net.
    """

    def __init__(self, init_net, predict_net, workspace, uninitialized):
        """Store the nets and the workspace used for execution.

        Args:
            init_net: Net created and run once, before the first prediction.
            predict_net: Net run on every call to :meth:`run`. Its
                ``device_option`` selects the device scope, and its
                ``external_output`` names the blobs returned as outputs.
            workspace: Workspace that blobs are fed into and nets run in.
            uninitialized: Names of the external inputs not initialized in
                ``workspace``. Positional inputs given to :meth:`run` are
                paired with these names in order.
        """
        super(Caffe2Rep, self).__init__()
        self.init_net = init_net
        self.predict_net = predict_net
        self.workspace = workspace
        # The list of uninitialized external_inputs in workspace; we need this
        # to pair the names with positional (sequence or single) inputs.
        self.uninitialized = uninitialized
        # Lazily-set flags so nets are created, and init_net is run, only once.
        self.nets_created = False
        self.ran_init_net = False

    @property
    def _name_scope(self):
        """Name scope used when feeding inputs given as a dict.

        Returns ``'gpu_<id>'`` when the prediction net targets a CUDA device,
        and an empty scope string otherwise.
        """
        device_option = self.predict_net.device_option
        if device_option.device_type == caffe2_pb2.CUDA:
            return 'gpu_{}'.format(device_option.cuda_gpu_id)
        return ''

    def run(self, inputs, **kwargs):
        """Feed ``inputs``, run the model, and return its outputs.

        Args:
            inputs: Accepted in one of three forms:
                * a dict mapping blob names to values, fed under
                  :attr:`_name_scope`;
                * a list or tuple with exactly one value per entry of
                  ``self.uninitialized``, fed in that order;
                * any other object, fed as the value of the first
                  uninitialized input.
            **kwargs: Forwarded to ``BackendRep.run``.

        Returns:
            An ``Outputs`` namedtupledict with one field per name in the
            prediction net's ``external_output``, holding the fetched blobs.

        Raises:
            RuntimeError: If a list/tuple input has the wrong number of
                values, or if a single input is given but the graph has no
                uninitialized inputs to receive it.
        """
        super(Caffe2Rep, self).run(inputs, **kwargs)
        with core.DeviceScope(self.predict_net.device_option):
            self._feed_inputs(inputs)
            self._ensure_nets_ready()
            self.workspace.RunNet(self.predict_net.name)
        output_names = self.predict_net.external_output
        output_values = [self.workspace.FetchBlob(name)
                         for name in output_names]
        return namedtupledict('Outputs', output_names)(*output_values)

    def _feed_inputs(self, inputs):
        """Feed ``inputs`` into the workspace (see :meth:`run` for the forms)."""
        if isinstance(inputs, dict):
            with core.NameScope(self._name_scope):
                for key, value in inputs.items():
                    self.workspace.FeedBlob(key, value)
        elif isinstance(inputs, (list, tuple)):
            if len(self.uninitialized) != len(inputs):
                raise RuntimeError('Expected {} values for uninitialized '
                                   'graph inputs ({}), but got {}.'.format(
                                       len(self.uninitialized),
                                       ', '.join(self.uninitialized),
                                       len(inputs)))
            # The name scope is already baked into the protobuf names, so the
            # uninitialized names are fed as-is (no NameScope here).
            for name, value in zip(self.uninitialized, inputs):
                self.workspace.FeedBlob(name, value)
        else:
            # Single (non-sequence) input: it goes to the first uninitialized
            # input. Fail with a clear message instead of a bare IndexError.
            if not self.uninitialized:
                raise RuntimeError('Got a single input, but the graph has no '
                                   'uninitialized inputs to feed it to.')
            self.workspace.FeedBlob(self.uninitialized[0], inputs)

    def _ensure_nets_ready(self):
        """Create both nets and run ``init_net``, each only on first use."""
        if not self.nets_created:
            self.workspace.CreateNet(self.init_net)
            self.workspace.CreateNet(self.predict_net)
            self.nets_created = True
        if not self.ran_init_net:
            self.workspace.RunNet(self.init_net.name)
            self.ran_init_net = True