import io

import onnx
import torch.onnx

from caffe2.python.core import BlobReference, Net
from caffe2.python.onnx.backend import Caffe2Backend


# Counter used to build a fresh default blob-name prefix on each call to
# PyTorchModule that does not supply its own prefix.
_next_idx = 0


# Clone net takes a dict instead of a lambda
# It should probably take a lambda, it is more flexible
# We fake dict here
class _FakeDict:
    """Dict-like adapter that answers ``get`` by calling a function."""

    def __init__(self, fn):
        # fn: callable taking a blob name and returning the remapped name.
        self.fn = fn

    def get(self, name, _):
        # The second argument (the dict-style default) is deliberately
        # ignored: every name is passed through ``fn``.
        return self.fn(name)


def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None):
    """
    Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built.

    Args:
        helper (caffe2.python.core.ModelHelper): the model helper where
            this imported network should be inserted

        model (torch.nn.Module): the model to be exported

        sample_arguments (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model.  Any non-Variable arguments will
            be hard-coded into the exported model; any Variable arguments
            will become inputs of the exported model, in the order they
            occur in args.  If args is a Variable, this is equivalent
            to having called it with a 1-ary tuple of that Variable.
            (Note: passing keyword arguments to the model is not currently
            supported.  Give us a shout if you need it.)

        caffe2_inputs (list of str or caffe2.python.core.BlobReference): the
            caffe2 Blobs that should be inputs to this network. There must
            be exactly one entry per input of the exported graph that has
            no initializer, given in the order those inputs appear in the
            graph (normally the same length and order as sample_arguments).

        prefix_name: prefix name to add to each member of the blob, if None
            then a fresh prefix pytorch_import_N/ is used

    Returns:
        A tuple of caffe2.python.core.BlobReference objects referring to the
        model's outputs, one per output of the exported graph. A tuple is
        returned even when the model has a single output.

    Raises:
        ValueError: if the number of entries in ``caffe2_inputs`` differs
            from the number of uninitialized inputs of the exported graph.
    """
    global _next_idx

    if prefix_name is None:
        prefix_name = "pytorch_import_" + str(_next_idx) + "/"
        _next_idx += 1

    # TODO: handle the case where model cannot be exported
    # and embed as a Python op in Caffe2
    f = io.BytesIO()
    torch.onnx.export(model, sample_arguments, f, export_params=True)
    onnx_model = onnx.load(io.BytesIO(f.getvalue()))
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)

    # Graph inputs that have an initializer are parameters stored in the
    # model; only the remaining inputs are fed from caffe2_inputs.
    initialized = {x.name for x in onnx_model.graph.initializer}
    uninitialized_names = [
        x.name for x in onnx_model.graph.input if x.name not in initialized
    ]
    # Index each name by its position among the *uninitialized* inputs, not
    # among all graph inputs: caffe2_inputs only has entries for the former,
    # so an index into the full input list could select the wrong blob or
    # run past the end when initializers are interleaved with real inputs.
    uninitialized_inputs = {
        name: i for i, name in enumerate(uninitialized_names)
    }

    if len(uninitialized_inputs) != len(caffe2_inputs):
        raise ValueError(
            f"Expected {len(uninitialized_inputs)} inputs but found {len(caffe2_inputs)}"
        )

    def remap_blob_name(name):
        # External inputs map onto the caller-supplied blobs; every other
        # blob is namespaced under the prefix.
        if name in uninitialized_inputs:
            idx = uninitialized_inputs[name]
            return str(caffe2_inputs[idx])
        return prefix_name + name

    predict_net = Net(predict_net).Clone("anon", _FakeDict(remap_blob_name))
    helper.net.AppendNet(predict_net)

    init_net = Net(init_net).Clone("anon", _FakeDict(remap_blob_name))
    helper.param_init_net.AppendNet(init_net)

    results = tuple(
        BlobReference(remap_blob_name(x.name), helper.net)
        for x in onnx_model.graph.output
    )
    return results