mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Adds basic nomnigraph python bindings for quickly playing with the graphs. Reviewed By: duc0 Differential Revision: D9441936 fbshipit-source-id: fd70f8ea279b28c766e40f124008800acd94bddd
57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import caffe2.python._import_c_extension as C
|
|
from caffe2.python import core
|
|
from caffe2.proto import caffe2_pb2
|
|
import os
|
|
from subprocess import Popen, PIPE
|
|
import errno
|
|
|
|
|
|
class NNModule(object):
    """Python wrapper around the nomnigraph ``C.NNModule`` extension object.

    The wrapped extension object is kept in ``self._NNModule``; the methods
    below forward to it.

    Args:
        net: Optional network to convert. Either a ``core.Net`` (its
            ``Proto()`` is used) or a ``caffe2_pb2.NetDef``. The protobuf is
            serialized and passed to ``C.NNModuleFromProtobuf``. When ``net``
            is ``None`` an empty ``C.NNModule()`` is created instead.

    Raises:
        TypeError: if ``net`` is not ``None``, a ``core.Net`` or a
            ``caffe2_pb2.NetDef``. ``TypeError`` subclasses ``Exception``,
            so existing ``except Exception`` handlers still catch it.
    """

    def __init__(self, net=None):
        if net is None:
            self._NNModule = C.NNModule()
            return

        if isinstance(net, core.Net):
            proto = net.Proto()
        elif isinstance(net, caffe2_pb2.NetDef):
            proto = net
        else:
            raise TypeError(
                "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
            )
        self._NNModule = C.NNModuleFromProtobuf(proto.SerializeToString())

    @property
    def dataFlow(self):
        """Return the result of the wrapped object's ``dataFlow()`` call.

        NOTE(review): presumably the module's data-flow graph; confirm
        against the C extension.
        """
        return self._NNModule.dataFlow()

    @staticmethod
    def _command_exists(name):
        """Return True if an executable regular file ``name`` is on PATH."""
        search_path = os.environ.get("PATH")
        if not search_path:
            # No PATH set: treat the command as absent instead of raising.
            return False
        for directory in search_path.split(os.pathsep):
            candidate = os.path.join(directory, name)
            # isfile() rules out directories, which also satisfy os.X_OK.
            if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
                return True
        return False

    def dumpDataFlow(self):
        """Display the text returned by the wrapped object's ``dotString()``.

        If a ``graph-easy`` executable is found on PATH, the text is piped to
        it on stdin. The child inherits this process's stdout. Otherwise the
        text is printed as-is.
        """
        dot = self._NNModule.dotString()
        if not self._command_exists("graph-easy"):
            print(dot)
            return

        # Avoid re-encoding data that is already bytes (e.g. ``str`` on py2).
        data = dot if isinstance(dot, bytes) else dot.encode("utf-8")
        proc = Popen(["graph-easy"], stdin=PIPE)
        # communicate() writes the input, ignores EPIPE/EINVAL if graph-easy
        # exits early (on write *and* on closing stdin), then waits.
        proc.communicate(data)
|
|
|
|
|
|
# Module-level aliases for the C extension's NeuralNetOperator and
# NeuralNetData types, so they can be used directly from this module.
NeuralNetOperator, NeuralNetData = C.NeuralNetOperator, C.NeuralNetData
|