from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import caffe2.python._import_c_extension as C from caffe2.python import core from caffe2.proto import caffe2_pb2 import os from subprocess import Popen, PIPE import errno class NNModule(object): def __init__(self, net=None, device_map=None): if net is not None: serialized_proto = None if isinstance(net, core.Net): serialized_proto = net.Proto().SerializeToString() elif isinstance(net, caffe2_pb2.NetDef): serialized_proto = net.SerializeToString() # Distributed if device_map is not None: serialized_device_map = {} for k in device_map: serialized_device_map[k] = device_map[k].SerializeToString() self._NNModule = C.NNModuleFromProtobufDistributed(serialized_proto, serialized_device_map) # Default elif serialized_proto: self._NNModule = C.NNModuleFromProtobuf(serialized_proto) else: raise Exception( "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types" ) else: self._NNModule = C.NNModule() @property def dataFlow(self): return self._NNModule.dataFlow() def convertToCaffe2Proto(self, old_proto=None): if not old_proto: old_proto = caffe2_pb2.NetDef() output = self._NNModule.convertToCaffe2Proto(old_proto) new_proto = caffe2_pb2.NetDef() new_proto.ParseFromString(output) return new_proto def match(self, pattern): for n in self.dataFlow.getMutableNodes(): m = C.matchSubgraph(n, pattern) if m: yield m def render(s): s = str(s) cmd_exists = lambda x: any( os.access(os.path.join(path, x), os.X_OK) for path in os.environ["PATH"].split(os.pathsep) ) if cmd_exists("graph-easy"): p = Popen("graph-easy", stdin=PIPE) try: p.stdin.write(s.encode("utf-8")) except IOError as e: if e.errno == errno.EPIPE or e.errno == errno.EINVAL: pass else: # Raise any other error. raise p.stdin.close() p.wait() else: print(s) NeuralNetOperator = C.NeuralNetOperator Operator = C.NeuralNetOperator NeuralNetData = C.NeuralNetData Data = C.NeuralNetData NNSubgraph = C.NNSubgraph NNMatchGraph = C.NNMatchGraph Graph = C.Graph Annotation = C.Annotation