mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Before this diff RNNOp was using TextFormat for representing steps. This diff is changing RNNOp to prefer NetDef argument instead. To be backward compatible it supports TextFormat for existing models, though we can compile RNNs without TextFormat as well. Reviewed By: salexspb Differential Revision: D5949330 fbshipit-source-id: 9336a8f5ccf30ad8d8e3a7067b9437e1704b1c9f
636 lines
23 KiB
Python
636 lines
23 KiB
Python
# Copyright (c) 2016-present, Facebook, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
##############################################################################
|
|
|
|
## @package model_helper
|
|
# Module caffe2.python.model_helper
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import core, scope, workspace
|
|
from caffe2.python.modeling import parameter_info
|
|
from caffe2.python.modeling.parameter_sharing import (
|
|
parameter_sharing_context,
|
|
)
|
|
from caffe2.python.optimizer_context import (
|
|
OptimizerContext,
|
|
DEFAULT_OPTIM,
|
|
)
|
|
|
|
from future.utils import viewitems, viewkeys
|
|
from itertools import chain
|
|
|
|
import logging
|
|
import six
|
|
|
|
|
|
# _known_working_ops are operators that do not need special care.
# ModelHelper.__getattr__ consults this list to decide whether an operator
# may be created silently, and ModelHelper.__dir__ advertises these names.
_known_working_ops = [
    "Accuracy",
    "Adam",
    "Add",
    "Adagrad",
    "SparseAdagrad",
    "AveragedLoss",
    "Cast",
    "Checkpoint",
    "ConstantFill",
    "Copy",
    "CopyGPUToCPU",
    "CopyCPUToGPU",
    "DequeueBlobs",
    "EnsureCPUOutput",
    "ExpandDims",
    "Flatten",
    "FlattenToVec",
    "LabelCrossEntropy",
    "LearningRate",
    "MakeTwoClass",
    "MatMul",
    "NCCLAllreduce",
    "NHWC2NCHW",
    "PackSegments",
    "Print",
    "PRelu",
    "ReduceFrontSum",
    "Scale",
    "ScatterWeightedSum",
    "Sigmoid",
    "SortedSegmentSum",
    "Snapshot",  # Note: snapshot is deprecated, use Checkpoint
    "Softmax",
    "SoftmaxWithLoss",
    "SquaredL2Distance",
    "Squeeze",
    "StopGradient",
    "Summarize",
    "Tanh",
    "Transpose",
    "UnpackSegments",
    "WeightedSum",
    "YellowFin",
]
|
|
|
|
|
|
class ModelHelper(object):
    """A helper model so we can manage models more easily. It contains net def
    and parameter storages. You can add an Operator yourself, e.g.

        model = model_helper.ModelHelper(name="train_net")
        # init your weight and bias as w and b
        w = model.param_init_net.XavierFill(...)
        b = model.param_init_net.ConstantFill(...)
        fc1 = model.FC([input, w, b], output, **kwargs)

    or you can use helper functions in brew module without manually
    defining parameter initializations and operators.

        model = model_helper.ModelHelper(name="train_net")
        fc1 = brew.fc(model, input, output, dim_in, dim_out, **kwargs)

    """

    def __init__(self, name=None, init_params=True, allow_not_known_ops=True,
                 skip_sparse_optim=False, param_model=None, arg_scope=None):
        """
        Creates the main net (named `name`, or "model" when no name is given)
        and the parameter bookkeeping structures.

        If param_model is given, this helper reuses (by reference, not by
        copy) its param_init_net, param_to_grad, params, _parameters_info and
        _computed_params. Otherwise fresh, empty structures are created and
        the init net is named '<name>_init'.

        arg_scope, when given, is merged over the default argument scope
        (order="NCHW", use_cudnn=True, cudnn_exhaustive_search=False).
        """
        self.name = name or "model"
        self.net = core.Net(self.name)

        if param_model is not None:
            self.param_init_net = param_model.param_init_net
            self.param_to_grad = param_model.param_to_grad
            self.params = param_model.params
            self._parameters_info = param_model._parameters_info
            self._computed_params = param_model._computed_params
        else:
            self.param_init_net = core.Net(self.name + '_init')
            self.param_to_grad = {}
            self.params = []
            self._parameters_info = {}
            self._computed_params = []

        self._param_info_deprecated = []
        self._devices = []
        self.gradient_ops_added = False
        self.init_params = init_params
        self.allow_not_known_ops = allow_not_known_ops
        self.skip_sparse_optim = skip_sparse_optim
        self.weights = []
        self.biases = []
        self._arg_scope = {
            'order': "NCHW",
            'use_cudnn': True,
            'cudnn_exhaustive_search': False,
        }
        if arg_scope is not None:
            # Please notice value as None is not acceptable. We are not
            # checking it here because we already have check in MakeArgument.
            self._arg_scope.update(arg_scope)

    @property
    def arg_scope(self):
        """The argument-scope dict built in __init__ (the live dict)."""
        return self._arg_scope

    def get_name(self):
        """Returns the model name."""
        return self.name

    def _infer_param_shape(self, param):
        """
        Looks through the ops of param_init_net for one that outputs `param`
        and carries a "shape" argument; returns that shape as a list of ints.
        Returns None when no such op/argument is found.
        """
        param_name = str(param)
        for op in self.param_init_net.Proto().op:
            if param_name not in op.output:
                continue
            for arg in op.arg:
                if arg.name == "shape":
                    return list(arg.ints)
        return None

    def _update_param_info_deprecated(self):
        """
        Brings the legacy _param_info_deprecated list in sync with
        self.params: appends a ParameterInfo for every param that does not
        have one yet, then refreshes the grad of every entry from
        param_to_grad.

        Raises ValueError if a not-yet-tracked param is not a BlobReference.
        """
        assert len(self._param_info_deprecated) <= len(self.params)
        for param in self.params[len(self._param_info_deprecated):]:
            if not isinstance(param, core.BlobReference):
                raise ValueError(
                    "Param %s must be a BlobReference!" % str(param))
            self._param_info_deprecated.append(parameter_info.ParameterInfo(
                param_id=len(self._param_info_deprecated),
                param=param,
                shape=self._infer_param_shape(param)))
        for info in self._param_info_deprecated:
            info.grad = self.param_to_grad.get(info.name)

    def _normalize_tags(self, tags):
        """
        Turns `tags` into a set of tags.

        A falsy value (None, empty container, empty string) yields an empty
        set; a list, tuple, set or frozenset yields the set of its elements;
        anything else is treated as one single tag.
        """
        if not tags:
            return set()
        if isinstance(tags, (list, tuple, set, frozenset)):
            return set(tags)
        return {tags}

    def create_param(self, param_name, shape, initializer, tags=None):
        """
        Creates parameter with a given name and initializer.

        If param_name is instance of BlobReference - then this blob will be
        used to store parameter (no logic will affect its location).

        If param_name is instance of a string type, then the final blob will
        be created in the CurrentNameScope with the respect of all parameter
        sharing logic, i.e. 'resolved_name_scope/param_name'.

        Parameter sharing logic is going to override CurrentNameScope
        according to the rules that are specified through ParameterSharing
        contexts, all ParameterSharing contexts are applied recursively until
        there are no extra overrides present, where on each step the best
        match will be applied first.

        The following examples should clarify the way ParameterSharing logic
        works:

        As an example if this function is called with parameter 'w':
        a. Call from some scope 'global_scope' with no Parameter sharing:
          'global_scope/w'
        b. Call from scope 'scope_b', with override {'scope_b': 'scope_a'}:
          'scope_a/w'
        c. Call from scope 'scope_a', with override {'scope_a': ''}:
          'scope_a/w'
        d. Call from scope 'scope_b/shared', with overrides
          {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
          'scope_a/w'
        e. Call from scope 'scope_b/unshared', with overrides
          {'scope_b/shared': 'scope_b', 'scope_b': 'scope_a'}:
          'scope_a/unshared/w'

        If a parameter with the resolved name already exists, its blob is
        returned (its recorded shape must equal `shape`) and no new
        initialization is added.

        Returns the blob of the parameter.
        Raises TypeError if param_name is neither a BlobReference nor a
        string.
        """
        # ParameterSharing works only for case when param_name is instance of
        # a string type. If param_name is a BlobReference - no attempt for
        # ParameterSharing will be applied.
        if isinstance(param_name, core.BlobReference):
            param_name = str(param_name)
        elif isinstance(param_name, six.string_types):
            # Parameter name will be equal to current Namescope that got
            # resolved with the respect of parameter sharing of the scopes.
            param_name = parameter_sharing_context.get_parameter_name(
                param_name)
        else:
            # Raising a bare string is itself a TypeError that hides the
            # message, so raise a real exception object instead.
            raise TypeError(
                "Unsupported type for param_name: {}".format(
                    type(param_name)))

        if param_name in self._parameters_info:
            assert self._parameters_info[param_name].shape == shape
            return self._parameters_info[param_name].blob

        param_info = initializer.create_param(
            param_name=core.BlobReference(param_name),
            init_net=self.param_init_net,
            shape=shape,
        )
        optim_context = OptimizerContext.current()
        for tag in self._normalize_tags(tags):
            if optim_context.has_optimizer(tag):
                # param_info will check optimizer has not been set
                param_info.optimizer = optim_context.get_optimizer(tag)
        if not param_info.optimizer and \
                optim_context.has_optimizer(DEFAULT_OPTIM):
            param_info.optimizer = optim_context.get_optimizer(DEFAULT_OPTIM)

        self._parameters_info[param_name] = param_info
        # Add param to legacy structs as well, so all other functions for
        # parameters are still working.
        self.AddParameter(param_info.blob, tags)
        return param_info.blob

    def get_param_info(self, param):
        """
        Returns the ParameterInfo registered for `param` (which must be a
        BlobReference), or None if nothing is registered for it.
        """
        assert isinstance(param, core.BlobReference), \
            "Param {} is not a BlobReference".format(param)
        return self._parameters_info.get(param, None)

    # This method is deprecated, use create_param method which
    # also does parameter initialization when needed
    def add_param_DEPRECATED(self, param, key=None, shape=None, length=None):
        """
        Registers `param` in the legacy parameter structures and returns the
        ParameterInfo created for it.

        When `key` is given and the net has an input record, `key` is
        replaced by the field name of the matching field blob. When `shape`
        is None it is inferred from param_init_net.

        Raises ValueError if param is not a BlobReference (AddParameter
        asserts the same condition first).
        """
        logging.warning("add_param method is DEPRECATED")
        self._update_param_info_deprecated()
        self.AddParameter(param)
        input_record = self.net.input_record()
        if key is not None and input_record is not None:
            idx = input_record.field_blobs().index(key)
            key = input_record.field_names()[idx]
        if shape is None:
            shape = self._infer_param_shape(param)
        if not isinstance(param, core.BlobReference):
            raise ValueError("Param %s must be a BlobReference!" % str(param))
        self._param_info_deprecated.append(parameter_info.ParameterInfo(
            param_id=len(self._param_info_deprecated),
            param=param,
            shape=shape,
            key=key,
            length=length,
        ))
        return self._param_info_deprecated[-1]

    # This method is deprecated, use get_param_info method
    def param_info(self, grad_type=None, id=None):
        """
        Legacy accessor for parameter infos.

        With `id`: returns the single info at that index (grad_type must be
        None). With `grad_type`: returns the infos whose grad_type() equals
        it. With neither: returns the whole legacy list.
        """
        logging.info("param_info method is DEPRECATED")
        self._update_param_info_deprecated()
        if id is not None:
            assert grad_type is None
            info = self._param_info_deprecated[id]
            assert info.param_id == id
            return info
        if grad_type is not None:
            return [
                info for info in self._param_info_deprecated
                if info.grad_type() == grad_type]
        return self._param_info_deprecated

    def AddParameter(self, param, tags=None):
        """
        Records `param` (a BlobReference) in the legacy lists.

        A param tagged COMPUTED_PARAM goes to _computed_params, any other
        goes to params. Independently of that, WEIGHT and BIAS tags add the
        param to weights and biases respectively.
        """
        assert isinstance(param, core.BlobReference)
        tags = self._normalize_tags(tags)
        if parameter_info.ParameterTags.COMPUTED_PARAM in tags:
            self._computed_params.append(param)
        else:
            self.params.append(param)

        if parameter_info.ParameterTags.WEIGHT in tags:
            self.weights.append(param)
        if parameter_info.ParameterTags.BIAS in tags:
            self.biases.append(param)

    @staticmethod
    def _NormalizeNamescope(namescope):
        """
        Returns the namescope prefix to filter by: the current name scope when
        `namescope` is None, otherwise `namescope` with the scope separator
        appended unless it is empty or already ends with the separator.
        """
        if namescope is None:
            return scope.CurrentNameScope()
        if namescope == '' or namescope.endswith(scope._NAMESCOPE_SEPARATOR):
            return namescope
        return namescope + scope._NAMESCOPE_SEPARATOR

    def GetParams(self, namescope=None, top_scope=False):
        '''
        Returns the params in current namescope.

        An empty namescope returns a copy of all params; otherwise the params
        whose name scope starts with the normalized namescope are returned.
        top_scope is accepted for compatibility; both of its values select
        the same params.
        '''
        namescope = ModelHelper._NormalizeNamescope(namescope)

        if namescope == '':
            return self.params[:]
        return [
            p for p in self.params
            if p.GetNameScope().startswith(namescope)
        ]

    def Proto(self):
        """Returns the protobuf of the main net."""
        return self.net.Proto()

    def InitProto(self):
        """Returns the protobuf of the parameter initialization net."""
        return self.param_init_net.Proto()

    def RunAllOnGPU(self, *args, **kwargs):
        """Forwards RunAllOnGPU to both the init net and the main net."""
        self.param_init_net.RunAllOnGPU(*args, **kwargs)
        self.net.RunAllOnGPU(*args, **kwargs)

    def CreateDB(self, blob_out, db, db_type, **kwargs):
        """Adds a CreateDB op to the init net and returns its output."""
        dbreader = self.param_init_net.CreateDB(
            [], blob_out, db=db, db_type=db_type, **kwargs)
        return dbreader

    def AddGradientOperators(self, *args, **kwargs):
        """
        Adds gradient operators to the main net (arguments are forwarded to
        net.AddGradientOperators) and returns the resulting gradient map.

        Afterwards param_to_grad is rebuilt and every param that has a
        gradient gets its ParameterInfo updated with it (or created).

        Raises RuntimeError when called a second time; Validate() is run
        first, so duplicate params fail its assertion.
        """
        if self.gradient_ops_added:
            raise RuntimeError("You cannot run AddGradientOperators twice.")
        self.Validate()

        # The flag must be set before get_param_to_grad is called below,
        # because that method refuses to run while it is False.
        self.gradient_ops_added = True
        self.grad_map = self.net.AddGradientOperators(*args, **kwargs)
        self.param_to_grad = self.get_param_to_grad(self.params)

        # Populate ParameterInfo for all parameters if missing
        # and add gradient blob information. So optimizers can use it
        for param, grad in viewitems(self.param_to_grad):
            param_info = self.get_param_info(param)
            if param_info:
                param_info.grad = grad
            else:
                self._parameters_info[param] = parameter_info.ParameterInfo(
                    param_id=None,
                    param=param,
                    grad=grad,
                )

        return self.grad_map

    def get_param_to_grad(self, params):
        '''
        Given a list of parameters returns a dict from a parameter
        to a corresponding gradient.

        Raises RuntimeError if AddGradientOperators has not been run.
        '''
        if not self.gradient_ops_added:
            raise RuntimeError("You need to run AddGradientOperators first.")
        # grad_map is looked up by blob name; params that have no entry in
        # it are left out of the result.
        param_to_grad = {}
        for p in params:
            name = str(p)
            if name in self.grad_map:
                param_to_grad[p] = self.grad_map[name]
        return param_to_grad

    def GetOptimizationParamInfo(self, params=None):
        '''
        Returns a list with the ParameterInfo of every param that has a
        gradient. If params is not specified, all parameters will be
        considered.

        When skip_sparse_optim is set, params whose gradient is a
        core.GradientSlice are left out.

        Raises RuntimeError if AddGradientOperators has not been run.
        '''
        if not self.gradient_ops_added:
            raise RuntimeError("Need to call AddGradientOperators first")

        param_to_grad = self.param_to_grad
        if params:
            param_to_grad = self.get_param_to_grad(params)

        return [
            self.get_param_info(param)
            for param, grad in viewitems(param_to_grad)
            if (
                not self.skip_sparse_optim or
                not isinstance(grad, core.GradientSlice)
            )
        ]

    def _Validate(self):
        '''
        Check for duplicate params.

        Returns the sorted list of param names that occur more than once in
        self.params (each name listed once); empty when there are none.
        '''
        seen = set()
        dupes = set()
        for name in (str(p) for p in self.params):
            if name in seen:
                dupes.add(name)
            else:
                seen.add(name)
        return sorted(dupes)

    def Validate(self):
        """Asserts that self.params contains no duplicate names."""
        dupes = self._Validate()
        assert dupes == [], "Duplicate params: {}".format(dupes)

    def GetComputedParams(self, namescope=None):
        '''
        Returns the computed params in current namescope. 'Computed params'
        are such parameters that are not optimized via gradient descent but
        are directly computed from data, such as the running mean and variance
        of Spatial Batch Normalization.
        '''
        namescope = ModelHelper._NormalizeNamescope(namescope)

        if namescope == '':
            return self._computed_params[:]
        return [
            p for p in self._computed_params
            if p.GetNameScope().startswith(namescope)
        ]

    def GetAllParams(self, namescope=None):
        """Returns the params followed by the computed params of a scope."""
        return self.GetParams(namescope) + self.GetComputedParams(namescope)

    def TensorProtosDBInput(
        self, unused_blob_in, blob_out, batch_size, db, db_type, **kwargs
    ):
        """TensorProtosDBInput.

        Creates a DB reader named "dbreader_" + db in the init net and adds a
        TensorProtosDBInput op reading from it to the main net. Returns the
        output of that op. unused_blob_in and **kwargs are not used.
        """
        dbreader_name = "dbreader_" + db
        dbreader = self.param_init_net.CreateDB(
            [], dbreader_name,
            db=db, db_type=db_type)
        return self.net.TensorProtosDBInput(
            dbreader, blob_out, batch_size=batch_size)

    def GetDevices(self):
        """Returns the device list; asserts that it is not empty."""
        assert len(self._devices) > 0, \
            "Use data_parallel_model to run model on multiple GPUs."
        return self._devices

    def __getattr__(self, op_type):
        """Catch-all for all other operators, mostly those without params.

        Only reached for attributes not found normally. Dunder names raise
        AttributeError. Names that are not registered operators raise
        RuntimeError with nearby operator names as a hint. Operators outside
        _known_working_ops raise RuntimeError when allow_not_known_ops is
        False and otherwise only log a warning. The lookup is then delegated
        to the main net.
        """
        if op_type.startswith('__'):
            raise AttributeError(op_type)

        if not core.IsOperator(op_type):
            raise RuntimeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ','.join(workspace.C.nearby_opnames(op_type)) + ']'
            )
        if op_type not in _known_working_ops:
            if not self.allow_not_known_ops:
                raise RuntimeError(
                    "Operator {} is not known to be safe".format(op_type))

            logging.warning("You are creating an op that the ModelHelper "
                            "does not recognize: {}.".format(op_type))
        return self.net.__getattr__(op_type)

    def __dir__(self):
        """Lists class attributes, instance attributes and known ops."""
        return sorted(set(chain(
            dir(type(self)),
            viewkeys(self.__dict__),
            _known_working_ops
        )))
|
|
|
|
|
|
def ExtractPredictorNet(
    net_proto,
    input_blobs,
    output_blobs,
    device=None,
    renames=None,
    disabled_inputs=None,
):
    '''
    Takes a model net for training and returns a net which can be
    used for prediction. For example, all gradient operators and
    input operators are removed.
    @param net_proto protobuf of the net you want to process (net.Proto())
    @param input_blobs list/set of blob names that are the inputs of predictor
    @param output_blobs list/set of blob names that are outputs of predictor
    @param device optional device option that is assigned
    @param renames dictionary of blob name to a new name (optional)
    @param disabled_inputs optional set of blobs that are 'switched off'. This
                will cause branches with those blobs as inputs to be removed
    @return tuple (predict_net, extra_inputs) where predict_net is a new
                core.Net named net_proto.name + "_predict" and extra_inputs
                lists its external inputs other than the (renamed)
                input_blobs
    Raises Exception when no op consumes any of input_blobs, when no op
    produces any of output_blobs, or when a selected op has is_test=0.

    NOTE: the selected ops of net_proto are modified in place (device option
    and, for RecurrentNetwork ops, their step net arguments) before being
    copied into the predictor net.
    '''
    predict_net = core.Net(net_proto.name + "_predict")
    predict_proto = predict_net.Proto()

    orig_external_inputs = set(net_proto.external_input)
    orig_external_outputs = set(net_proto.external_output)
    input_blobs = {str(b) for b in input_blobs}
    known_blobs = set(orig_external_inputs).union(input_blobs)
    output_blobs = {str(b) for b in output_blobs}
    external_inputs = set(input_blobs)
    external_outputs = set(output_blobs)

    if renames is None:
        renames = {}

    if disabled_inputs is not None:
        known_blobs = known_blobs - set(disabled_inputs)

    ops = list(net_proto.op)

    # Find the range of ops that we should include: from the first op that
    # reads one of the inputs (StopGradient ops do not count) up to the last
    # op that writes one of the outputs. min()/max() raise ValueError when
    # there is no candidate.
    try:
        first_op_with_input = min(
            j for j, candidate in enumerate(ops)
            if input_blobs.intersection(candidate.input) and
            candidate.type != 'StopGradient'
        )
    except ValueError:
        raise Exception("No ops with input={}".format(input_blobs))
    try:
        last_op_with_output = max(
            j for j, candidate in enumerate(ops)
            if output_blobs.intersection(candidate.output)
        )
    except ValueError:
        raise Exception("No ops with output={}".format(output_blobs))

    def validate_op(op):
        # Check that the op does not have is_test = 0 set. This is a common
        # pitfall with SpatialBN op, at least.
        for arg in op.arg:
            if arg.name == "is_test" and arg.i == 0:
                raise Exception(
                    "An operator had is_test=0, did you try to extract a " +
                    "predictor from a train model (instead of test model)?" +
                    " Op was: {}".format(str(op))
                )

    def rename_list(proto_list):
        # proto lists don't support assignments, so build the renamed
        # content separately and swap it in as a whole.
        renamed = [renames.get(b, b) for b in proto_list]
        del proto_list[:]
        proto_list.extend(renamed)

    # Iterate through the ops and only include those whose inputs
    # we can satisfy.
    for op in ops[first_op_with_input:(last_op_with_output + 1)]:
        if not known_blobs.issuperset(op.input):
            logging.debug(
                "Op {} had unknown inputs: {}".format(
                    op.type, set(op.input).difference(known_blobs)
                )
            )
            continue

        # Special handling for recurrent nets
        # TODO: when standard argument type for "nets" is introduced,
        # this can be more general
        if op.type == 'RecurrentNetwork':
            for arg in op.arg:
                if arg.name == 'backward_step_net':
                    arg.ClearField('n')
                elif arg.name == 'step_net':
                    for step_op in arg.n.op:
                        rename_list(step_op.input)
                        rename_list(step_op.output)
                        if device is not None:
                            step_op.device_option.device_type = \
                                device.device_type
                            step_op.device_option.cuda_gpu_id = \
                                device.cuda_gpu_id

                    rename_list(arg.n.external_input)
                    rename_list(arg.n.external_output)

                    # Add additional external inputs
                    # NOTE(review): this intersects the already renamed
                    # step-net inputs with the original (un-renamed) names,
                    # so renamed blobs are not picked up here - confirm
                    # that this is intended.
                    external_inputs.update(
                        set(arg.n.external_input).intersection(
                            orig_external_inputs
                        )
                    )

        if device is not None:
            op.device_option.device_type = device.device_type
            op.device_option.cuda_gpu_id = device.cuda_gpu_id
        validate_op(op)
        predict_proto.op.extend([op])
        known_blobs.update(op.output)
        external_inputs.update(
            set(op.input).intersection(orig_external_inputs)
        )
        external_outputs.update(
            set(op.output).intersection(orig_external_outputs)
        )

    # Predictor net's external inputs and outputs include only those
    # that are part of this net.
    predict_proto.external_input.extend(external_inputs)
    predict_proto.external_output.extend(external_outputs)

    rename_list(predict_proto.external_input)
    rename_list(predict_proto.external_output)

    renamed_input_blobs = [renames.get(b, b) for b in input_blobs]

    for op in predict_proto.op:
        rename_list(op.input)
        rename_list(op.output)

    return predict_net, list(
        set(predict_proto.external_input) - set(renamed_input_blobs)
    )
|