mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30679 Caffe2 expects quantized ops to be in NHWC format, while PyTorch inputs are in NCHW. Add a JIT pass that inserts a permute (nchw2nhwc) before each conv op and a permute (nhwc2nchw) after the conv op. Use the graph rewriter to find consecutive redundant permutes and remove them from the graph. Test Plan: python test/onnx/test_pytorch_onnx_caffe2_quantized.py TestQuantizedOps Imported from OSS Differential Revision: D18790518 fbshipit-source-id: 4dd39cf0b31b21f5586c0edfdce2260d4e245112
66 lines
2.9 KiB
Python
66 lines
2.9 KiB
Python
# @package onnx
|
|
# Module caffe2.python.onnx.backend_rep
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import core
|
|
from caffe2.proto import caffe2_pb2
|
|
from onnx.backend.base import BackendRep, namedtupledict
|
|
|
|
class Caffe2Rep(BackendRep):
    """Backend representation that executes a Caffe2 init/predict net pair.

    Inputs are fed into ``workspace`` as blobs. Both nets are created once,
    the init net is run once, and the predict net is run on every ``run``
    call.
    """

    def __init__(self, init_net, predict_net, workspace, uninitialized):
        super(Caffe2Rep, self).__init__()
        self.init_net = init_net
        self.predict_net = predict_net
        self.workspace = workspace
        # Names of the external inputs that are not initialized in the
        # workspace; sequence inputs are paired with these names by position.
        self.uninitialized = uninitialized
        # Flipped to True by run() after the corresponding one-time step.
        self.nets_created = False
        self.ran_init_net = False

    @property
    def _name_scope(self):
        """Return ``'gpu_<device_id>'`` for a CUDA predict net, else ``''``."""
        device = self.predict_net.device_option
        if device.device_type != caffe2_pb2.CUDA:
            return ''
        return 'gpu_{}'.format(device.device_id)

    def run(self, inputs, **kwargs):
        """Feed ``inputs``, run the nets and return the predict net outputs.

        ``inputs`` may be:
          * a dict mapping blob names to values (fed under ``_name_scope``),
          * a list or tuple of values, matched in order against
            ``self.uninitialized``,
          * any other object, fed as the value of ``self.uninitialized[0]``.

        Returns an ``Outputs`` namedtupledict with one field per entry of
        ``predict_net.external_output``.

        Raises RuntimeError when a list/tuple of inputs does not have the
        same length as ``self.uninitialized``.
        """
        super(Caffe2Rep, self).run(inputs, **kwargs)
        ws = self.workspace
        blob_names = self.uninitialized

        # NOTE(review): nesting reconstructed from a de-indented listing;
        # confirm that net creation/run belongs inside DeviceScope.
        with core.DeviceScope(self.predict_net.device_option):
            if isinstance(inputs, dict):
                with core.NameScope(self._name_scope):
                    for blob_name, blob in inputs.items():
                        ws.FeedBlob(blob_name, blob)
            elif isinstance(inputs, (list, tuple)):
                expected, received = len(blob_names), len(inputs)
                if expected != received:
                    message = ('Expected {} values for uninitialized '
                               'graph inputs ({}), but got {}.')
                    raise RuntimeError(message.format(
                        expected, ', '.join(blob_names), received))
                # namescope already baked into protobuf
                for blob_name, blob in zip(blob_names, inputs):
                    ws.FeedBlob(blob_name, blob)
            else:
                # Anything else is treated as the single input value.
                ws.FeedBlob(blob_names[0], inputs)

            # Create both nets only on the first call.
            if not self.nets_created:
                ws.CreateNet(self.init_net)
                ws.CreateNet(self.predict_net)
                self.nets_created = True
            # The init net runs once; the predict net runs on every call.
            if not self.ran_init_net:
                ws.RunNet(self.init_net.name)
                self.ran_init_net = True
            ws.RunNet(self.predict_net.name)

        output_names = self.predict_net.external_output
        fetched = []
        for output_name in output_names:
            try:
                blob = ws.FetchBlob(output_name)
            except Exception:
                # Fall back to the Int8 fetch when the plain fetch raises.
                blob = ws.FetchInt8Blob(output_name)
            fetched.append(blob)
        return namedtupledict('Outputs', output_names)(*fetched)
|