mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Added mobile exporter
Summary: Basically takes in a live net and creates an init_net and predict_net which can be written to file and run in Predictor Reviewed By: salexspb Differential Revision: D4989425 fbshipit-source-id: 8052065da9ed763d48bd9e1e19f7697ef60a2829
This commit is contained in:
parent
db1d62caf7
commit
c55be38e63
71
caffe2/python/predictor/mobile_exporter.py
Normal file
71
caffe2/python/predictor/mobile_exporter.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
## @package mobile_exporter
|
||||
# Module caffe2.python.mobile_exporter
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
from caffe2.python import core, utils
|
||||
from caffe2.proto import caffe2_pb2
|
||||
|
||||
|
||||
def Export(workspace, net, params):
|
||||
"""Returns init_net and predict_net suitable for writing to disk
|
||||
and loading into a Predictor"""
|
||||
predict_net = caffe2_pb2.NetDef()
|
||||
predict_net.CopyFrom(net.Proto())
|
||||
init_net = caffe2_pb2.NetDef()
|
||||
# Populate the init_net.
|
||||
ssa, blob_versions = core.get_ssa(net)
|
||||
inputs = []
|
||||
for versioned_inputs, _ in ssa:
|
||||
inputs += [name for name, _ in versioned_inputs]
|
||||
|
||||
input_blobs = [blob_name for blob_name, version in
|
||||
blob_versions.items()
|
||||
if version == 0 and blob_name not in params]
|
||||
# Blobs that are never used as an input to another layer,
|
||||
# i.e. strictly output blobs.
|
||||
output_blobs = [blob_name for blob_name, version in
|
||||
blob_versions.items()
|
||||
if version != 0 and blob_name not in inputs]
|
||||
|
||||
for blob_ref in params:
|
||||
blob_name = str(blob_ref)
|
||||
blob = workspace.FetchBlob(blob_name)
|
||||
init_net.op.extend(
|
||||
[
|
||||
core.CreateOperator(
|
||||
"GivenTensorFill", [], [blob_name],
|
||||
arg=[
|
||||
utils.MakeArgument("shape", blob.shape),
|
||||
utils.MakeArgument("values", blob)
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
# We have to make sure the blob exists in the namespace
|
||||
# and we can do so with fake data. (Which is immediately overwritten
|
||||
# by any typical usage)
|
||||
for blob_name in input_blobs:
|
||||
init_net.op.extend(
|
||||
[
|
||||
core.CreateOperator(
|
||||
"GivenTensorFill", [], [blob_name],
|
||||
arg=[
|
||||
utils.MakeArgument("shape", [1, 1]),
|
||||
utils.MakeArgument("values", [0.0])
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Now we make input/output_blobs line up with what Predictor expects.
|
||||
del predict_net.external_input[:]
|
||||
predict_net.external_input.extend(input_blobs)
|
||||
# For populating weights
|
||||
predict_net.external_input.extend(net.Proto().external_input)
|
||||
# Ensure the output is also consistent with what we want
|
||||
del predict_net.external_output[:]
|
||||
predict_net.external_output.extend(output_blobs)
|
||||
return init_net, predict_net
|
||||
70
caffe2/python/predictor/mobile_exporter_test.py
Normal file
70
caffe2/python/predictor/mobile_exporter_test.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
from caffe2.python.test_util import TestCase
|
||||
from caffe2.python import workspace, brew
|
||||
from caffe2.python.model_helper import ModelHelper
|
||||
from caffe2.python.predictor import mobile_exporter
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestMobileExporter(TestCase):
|
||||
def test_mobile_exporter(self):
|
||||
model = ModelHelper(name="mobile_exporter_test_model")
|
||||
# Test LeNet
|
||||
brew.conv(model, 'data', 'conv1', dim_in=1, dim_out=20, kernel=5)
|
||||
brew.max_pool(model, 'conv1', 'pool1', kernel=2, stride=2)
|
||||
brew.conv(model, 'pool1', 'conv2', dim_in=20, dim_out=50, kernel=5)
|
||||
brew.max_pool(model, 'conv2', 'pool2', kernel=2, stride=2)
|
||||
brew.fc(model, 'pool2', 'fc3', dim_in=50 * 4 * 4, dim_out=500)
|
||||
brew.relu(model, 'fc3', 'fc3')
|
||||
brew.fc(model, 'fc3', 'pred', 500, 10)
|
||||
brew.softmax(model, 'pred', 'out')
|
||||
|
||||
# Create our mobile exportable networks
|
||||
workspace.RunNetOnce(model.param_init_net)
|
||||
init_net, predict_net = mobile_exporter.Export(
|
||||
workspace, model.net, model.params
|
||||
)
|
||||
|
||||
# Populate the workspace with data
|
||||
np_data = np.random.rand(1, 1, 28, 28).astype(np.float32)
|
||||
workspace.FeedBlob("data", np_data)
|
||||
|
||||
workspace.CreateNet(model.net)
|
||||
workspace.RunNet(model.net)
|
||||
ref_out = workspace.FetchBlob("out")
|
||||
|
||||
# Clear the workspace
|
||||
workspace.ResetWorkspace()
|
||||
|
||||
# Populate the workspace with data
|
||||
workspace.RunNetOnce(init_net)
|
||||
# Fake "data" is populated by init_net, we have to replace it
|
||||
workspace.FeedBlob("data", np_data)
|
||||
|
||||
# Overwrite the old net
|
||||
workspace.CreateNet(predict_net, True)
|
||||
workspace.RunNet(predict_net.name)
|
||||
manual_run_out = workspace.FetchBlob("out")
|
||||
np.testing.assert_allclose(
|
||||
ref_out, manual_run_out, atol=1e-10, rtol=1e-10
|
||||
)
|
||||
|
||||
# Clear the workspace
|
||||
workspace.ResetWorkspace()
|
||||
|
||||
# Predictor interface test (simulates writing to disk)
|
||||
predictor = workspace.Predictor(
|
||||
init_net.SerializeToString(), predict_net.SerializeToString()
|
||||
)
|
||||
|
||||
# Output is a vector of outputs but we only care about the first and only result
|
||||
predictor_out = predictor.run([np_data])
|
||||
assert len(predictor_out) == 1
|
||||
predictor_out = predictor_out[0]
|
||||
|
||||
np.testing.assert_allclose(
|
||||
ref_out, predictor_out, atol=1e-10, rtol=1e-10
|
||||
)
|
||||
Loading…
Reference in New Issue
Block a user