// Python bindings for nomnigraph: generic graphs, NNModule/NNGraph, node
// references, subgraph matching and Caffe2 annotations. Everything is
// registered on the module by addNomnigraphMethods, which is hooked in via
// REGISTER_PYBIND_ADDITION at the bottom of this file.
//
// NOTE(review): in this copy of the file the contents of angle brackets appear
// to have been lost -- e.g. the two bare `#include` directives, `std::map` /
// `std::vector` / `std::pair` without element types, and `isa(...)`,
// `dyn_cast(...)`, `nn::is(...)`, `nn::get(...)`, `nn::nodeIterator(...)`,
// `py::class_ ...` without their template arguments. Those tokens are kept
// exactly as found; restore them from the upstream source before building.
#include "caffe2/core/context.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/types.h"
#include "caffe2/opt/converter.h"
#include "caffe2/opt/distributed.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/python/dlpack.h"
#include "caffe2/python/pybind_state_registry.h"
#include "caffe2/utils/proto_utils.h"
#include "nomnigraph/Converters/Dot.h"
#include "nomnigraph/Graph/Algorithms.h"
#include "nomnigraph/Representations/NeuralNet.h"

// NOTE(review): header names missing in this copy (see file note above).
#include
#include

using ListCasterBase = pybind11::detail::list_caster<
    std::vector,
    nom::repr::NNGraph::NodeRef>;
namespace pybind11 {
namespace detail {
// Caster for vectors of NNGraph::NodeRef. Both cast overloads ignore the
// return_value_policy they are given and always convert the elements with
// return_value_policy::reference, so Python receives non-owning references
// (presumably because the nodes are owned by their graph -- confirm).
template <>
struct type_caster> : ListCasterBase {
  static handle cast(
      const std::vector& src,
      return_value_policy,
      handle parent) {
    return ListCasterBase::cast(src, return_value_policy::reference, parent);
  }
  // Pointer overload: forwards to the reference overload above.
  static handle cast(
      const std::vector* src,
      return_value_policy pol,
      handle parent) {
    return cast(*src, pol, parent);
  }
};
} // namespace detail
} // namespace pybind11

namespace caffe2 {
namespace python {

using namespace nom::repr;

namespace {

// Label callback for convertToDotString on an NNGraph. Operator nodes get
// their name as "label" and shape "box"; data nodes get their name as
// "label"; any other node gets an empty attribute map.
std::map NNPrinter(
    typename nom::repr::NNGraph::NodeRef node) {
  std::map labelMap;
  assert(node->data() && "Node doesn't have data, can't render it");
  if (isa(node->data())) {
    auto* op = dyn_cast(node->data().get());
    labelMap["label"] = op->getName();
    labelMap["shape"] = "box";
  } else if (isa(node->data())) {
    auto tensor = dyn_cast(node->data().get());
    labelMap["label"] = tensor->getName();
  }
  return labelMap;
}

using Graph = nom::Graph;

// Label callback for convertToDotString on the generic Graph: the label is
// the Python str() of the node's payload.
std::map GraphPrinter(typename Graph::NodeRef node) {
  std::map labelMap;
  assert(node->data() && "Node doesn't have data, can't render it");
  labelMap["label"] = py::str(node->data());
  return labelMap;
}

} // namespace

// Registers all nomnigraph classes and free functions on module `m`.
// pybind11 tries same-named overloads (e.g. "createNode", "deleteEdge") in
// registration order, so the order of the .def calls below is significant.
void addNomnigraphMethods(pybind11::module& m) {
  // Generic Graph methods
  py::class_ graph(m, "Graph");
  py::class_> node(m, "Node");
  py::class_> edge(m, "Edge");
  graph.def(py::init<>())
      .def(
          "__repr__",
          [](Graph* g) {
            return nom::converters::convertToDotString(g, GraphPrinter);
          })
      .def(
          "createEdge",
          [](Graph* g, Graph::NodeRef a, Graph::NodeRef b) {
            return g->createEdge(a, b);
          },
          py::return_value_policy::reference_internal)
      .def(
          "createNode",
          [](Graph* g, py::object obj) {
            return g->createNode(std::move(obj));
          },
          py::return_value_policy::reference_internal);

  // NNModule methods

  // Parses a serialized NetDef and returns (NNModule, ns), where ns is the
  // node vector filled in by convertToNNModule.
  m.def("NNModuleFromProtobuf", [](py::bytes def) {
    caffe2::NetDef proto;
    CAFFE_ENFORCE(ParseProtoFromLargeString(def.cast(), &proto));
    std::vector ns;
    auto nn = caffe2::convertToNNModule(proto, false, &ns);
    // ns is not used afterwards, so move it into the result as well.
    return std::pair>(
        std::move(nn), std::move(ns));
  });

  // Like NNModuleFromProtobuf, but also takes a map from blob name to a
  // serialized DeviceOption, parsed here and handed to convertToNNModule.
  m.def(
      "NNModuleFromProtobufDistributed",
      [](py::bytes def, std::map blobToDeviceMap) {
        // Named so it does not shadow the module parameter `m`.
        std::map deviceOptions;
        for (const auto& el : blobToDeviceMap) {
          caffe2::DeviceOption d;
          CAFFE_ENFORCE(
              ParseProtoFromLargeString(el.second.cast(), &d));
          deviceOptions[el.first] = d;
        }

        caffe2::NetDef proto;
        CAFFE_ENFORCE(
            ParseProtoFromLargeString(def.cast(), &proto));

        return caffe2::convertToNNModule(proto, deviceOptions);
      });

  m.def("replaceProducer", &nn::replaceProducer);
  m.def("replaceAllUsesWith", &nn::replaceAllUsesWith);
  m.def("replaceAsConsumer", &nn::replaceAsConsumer);

  py::class_ nnmodule(m, "NNModule");
  nnmodule.def(py::init<>())
      .def(
          "dataFlow",
          [](NNModule* nn) -> NNGraph* { return &nn->dataFlow; },
          py::return_value_policy::reference_internal)
      .def(
          "createUniqueDataNode",
          &NNModule::createUniqueDataNode,
          py::return_value_policy::reference_internal)
      .def(
          "convertToCaffe2Proto",
          [](NNModule& nn, py::object def) {
            // CAFFE_ENFORCE joins its message arguments with no separator,
            // so the spacing and the "or" must live inside the literals.
            CAFFE_ENFORCE(
                pybind11::hasattr(def, "SerializeToString"),
                "convertToCaffe2Proto takes either no args ",
                "or a NetDef");
            auto str = def.attr("SerializeToString")();
            caffe2::NetDef proto;
            // Any object with SerializeToString passes the check above, so
            // refuse to convert a proto that did not parse as a NetDef.
            CAFFE_ENFORCE(
                proto.ParseFromString(py::bytes(str)),
                "convertToCaffe2Proto could not parse the argument as a NetDef");
            auto new_proto = caffe2::convertToCaffe2Proto(nn, proto);
            std::string out;
            new_proto.SerializeToString(&out);
            return py::bytes(out);
          })
      .def(
          "getExecutionOrder",
          [](NNModule& nn) {
            // Called on nn before the control-flow graph is walked.
            nn::coalesceInsertedDataDependencies(&nn);
            std::vector out;
            // Instructions are emitted SCC by SCC, basic block by basic
            // block, in the order tarjans() returns them.
            auto sccs = nom::algorithm::tarjans(&nn.controlFlow);
            for (const auto& scc : sccs) {
              for (const auto& bb : scc.getNodes()) {
                for (const auto& instr : bb->data().getInstructions()) {
                  out.emplace_back(instr);
                }
              }
            }
            return out;
          },
          py::return_value_policy::reference_internal)
      .def("replaceSubgraph", &NNModule::replaceSubgraph)
      .def("deleteSubgraph", &NNModule::deleteSubgraph);

  auto getTensors = [](NNGraph* g) {
    return nn::nodeIterator(*g);
  };
  auto getOperators = [](NNGraph* g) {
    return nn::nodeIterator(*g);
  };
  // NNGraph methods
  py::class_ nngraph(m, "NNGraph");
  nngraph
      .def(
          "__repr__",
          [](NNGraph* g) {
            return nom::converters::convertToDotString(g, NNPrinter);
          })
      .def(
          "createEdge",
          [](NNGraph* g, NNGraph::NodeRef a, NNGraph::NodeRef b) {
            // Only operator<->data edges are allowed (either direction).
            CAFFE_ENFORCE(
                (nn::is(a) && nn::is(b)) ||
                    (nn::is(b) && nn::is(a)),
                "Edges must exist between NeuralNetOperator and NeuralNetData");
            g->createEdge(a, b);
          })
      .def("deleteEdge", &NNGraph::deleteEdge)
      .def(
          "deleteEdge",
          [](NNGraph* g, NNGraph::NodeRef a, NNGraph::NodeRef b) {
            // Deleting a non-existent edge is a no-op.
            auto existingEdge = g->getEdgeIfExists(a, b);
            if (existingEdge) {
              g->deleteEdge(existingEdge);
            }
          })
      .def(
          "createNode",
          [](NNGraph* g, GenericOperator& op) {
            return g->createNode(
                nom::util::make_unique(op.getName()));
          },
          py::return_value_policy::reference_internal)
      .def(
          "createNode",
          [](NNGraph* g, nom::repr::Tensor& tensor) {
            return g->createNode(
                nom::util::make_unique(tensor.getName()));
          },
          py::return_value_policy::reference_internal)
      .def(
          "createNode",
          [](NNGraph* g, py::object op_def) {
            // CAFFE_ENFORCE joins its message arguments with no separator.
            CAFFE_ENFORCE(
                pybind11::hasattr(op_def, "SerializeToString"),
                "createNode takes either OperatorDef ",
                "or ng.NeuralNetOperator");
            auto str = op_def.attr("SerializeToString")();
            OperatorDef op;
            CAFFE_ENFORCE(
                op.ParseFromString(py::bytes(str)),
                "createNode could not parse the argument as an OperatorDef");
            if (op.input().size() || op.output().size()) {
              LOG(WARNING)
                  << "Input and output specifications are "
                  << "dropped when converting a single operator to nomnigraph. "
                  << "Use ng.NNModule(NetDef&) to preserve these.";
            }
            return g->createNode(convertToNeuralNetOperator(op));
          },
          py::return_value_policy::reference_internal)
      .def("deleteNode", &NNGraph::deleteNode)
      .def(
          "replaceNode",
          [](NNGraph* g, NNGraph::NodeRef old_node, NNGraph::NodeRef new_node) {
            g->replaceNode(old_node, new_node);
          })
      .def(
          "getMutableNodes",
          &NNGraph::getMutableNodes,
          py::return_value_policy::reference_internal)
      .def_property_readonly(
          "nodes",
          &NNGraph::getMutableNodes,
          py::return_value_policy::reference_internal)
      .def_property_readonly(
          "operators",
          getOperators,
          py::return_value_policy::reference_internal)
      .def_property_readonly(
          "tensors", getTensors, py::return_value_policy::reference_internal);

  // Node level methods
  using NodeType = nom::Node>;
  py::class_ noderef(m, "NodeRef");
  // Name of the operator or tensor behind the node, "Unknown" otherwise.
  auto getName = [](NNGraph::NodeRef n) {
    if (nn::is(n)) {
      return nn::get(n)->getName();
    } else if (nn::is(n)) {
      return nn::get(n)->getName();
    }
    return std::string("Unknown");
  };
  auto getType = [](NNGraph::NodeRef n) {
    if (nn::is(n)) {
      return "Tensor";
    } else if (nn::is(n)) {
      return "Operator";
    }
    return "Unknown";
  };
  // The accessors below enforce the node kind before touching it.
  auto getOperator = [](NNGraph::NodeRef n) {
    CAFFE_ENFORCE(nn::is(n));
    return nn::get(n);
  };
  auto getTensor = [](NNGraph::NodeRef n) {
    CAFFE_ENFORCE(nn::is(n));
    return nn::get(n);
  };
  auto getInputs = [](NNGraph::NodeRef n) {
    CAFFE_ENFORCE(nn::is(n));
    return nn::getInputs(n);
  };
  auto getOutputs = [](NNGraph::NodeRef n) {
    CAFFE_ENFORCE(nn::is(n));
    return nn::getOutputs(n);
  };
  auto getProducer = [](NNGraph::NodeRef n) {
    CAFFE_ENFORCE(nn::is(n));
    return nn::getProducer(n);
  };
  auto getConsumers = [](NNGraph::NodeRef n) {
    CAFFE_ENFORCE(nn::is(n));
    return nn::getConsumers(n);
  };
  // Stores a copy of `annot` on the node's payload.
  // NOTE(review): unlike the accessors above there is no node-kind check
  // here; if nn::get returns null for a non-matching node this dereferences
  // null -- confirm and add a CAFFE_ENFORCE once the template args are back.
  auto setAnnotation = [](NNGraph::NodeRef n, Caffe2Annotation& annot) {
    auto* nnOp = nn::get(n);
    nnOp->setAnnotation(nom::util::make_unique(annot));
  };
  auto getAnnotation = [](NNGraph::NodeRef n) {
    return getOrAddCaffe2Annotation(n);
  };

  noderef
      .def(
          "isOperator",
          [](NNGraph::NodeRef n) { return nn::is(n); })
      .def(
          "isTensor",
          [](NNGraph::NodeRef n) { return nn::is(n); })
      .def("getType", getType)
      .def_property_readonly("type", getType)
      .def("getName", getName)
      .def_property_readonly("name", getName)
      .def(
          "getOperator",
          getOperator,
          py::return_value_policy::reference_internal)
      .def("getTensor", getTensor, py::return_value_policy::reference_internal)
      .def_property_readonly(
          "operator", getOperator, py::return_value_policy::reference)
      .def_property_readonly(
          "tensor", getTensor, py::return_value_policy::reference)
      .def("getInputs", getInputs, py::return_value_policy::reference)
      .def("getOutputs", getOutputs, py::return_value_policy::reference)
      .def("hasProducer", [](NNGraph::NodeRef n) { return nn::hasProducer(n); })
      .def("getProducer", getProducer, py::return_value_policy::reference)
      .def("getConsumers", getConsumers, py::return_value_policy::reference)
      .def_property_readonly(
          "inputs", getInputs, py::return_value_policy::reference)
      .def_property_readonly(
          "outputs", getOutputs, py::return_value_policy::reference)
      .def_property_readonly(
          "producer", getProducer, py::return_value_policy::reference)
      .def_property_readonly(
          "consumers", getConsumers, py::return_value_policy::reference)
      .def("getAnnotation", getAnnotation, py::return_value_policy::reference)
      .def("setAnnotation", setAnnotation)
      .def_property(
          "annotation",
          getAnnotation,
          setAnnotation,
          py::return_value_policy::reference)
      .def(
          "getOperatorPredecessors",
          [](NNGraph::NodeRef n) {
            CAFFE_ENFORCE(nn::is(n));
            // Producer of each input tensor; a producer feeding several
            // inputs is listed once per input edge.
            std::vector pred;
            for (const auto& inEdge : n->getInEdges()) {
              auto data = inEdge->tail();
              if (nn::hasProducer(data)) {
                pred.emplace_back(nn::getProducer(data));
              }
            }
            return pred;
          },
          py::return_value_policy::reference)
      .def(
          "getOperatorSuccessors",
          [](NNGraph::NodeRef n) {
            CAFFE_ENFORCE(nn::is(n));
            // Every consumer of every output tensor, in edge order.
            std::vector succ;
            for (const auto& outEdge : n->getOutEdges()) {
              auto data = outEdge->head();
              for (const auto& consumer : nn::getConsumers(data)) {
                succ.emplace_back(consumer);
              }
            }
            return succ;
          },
          py::return_value_policy::reference);

  py::class_ nnop(m, "NeuralNetOperator");
  py::class_ nndata(m, "NeuralNetData");

  nnop.def(py::init()).def("getName", &NeuralNetOperator::getName);
  nndata.def(py::init()).def("getName", &NeuralNetData::getName);

  // Subgraph matching API
  py::class_ nnsubgraph(m, "NNSubgraph");
  nnsubgraph.def(py::init<>())
      .def("__len__", [](NNSubgraph& s) { return s.getNodes().size(); })
      .def(
          "__repr__",
          [](NNSubgraph* g) {
            return nom::converters::convertToDotString(*g, NNPrinter);
          })
      .def(
          "addNode",
          [](NNSubgraph* sg, NNGraph::NodeRef node) { sg->addNode(node); })
      .def(
          "induceEdges",
          [](NNSubgraph* sg) { nom::algorithm::induceEdges(sg); })
      .def_property_readonly(
          "nodes",
          [](NNSubgraph& s) {
            std::vector out;
            for (auto n : s.getNodes()) {
              out.emplace_back(n);
            }
            return out;
          },
          py::return_value_policy::reference)
      .def("hasNode", [](NNSubgraph& s, NNGraph::NodeRef n) {
        return s.hasNode(n);
      });

  py::class_ nnMatchGraph(m, "NNMatchGraph");
  nnMatchGraph.def(py::init<>());

  using MatchPredicateType = nom::Node;
  py::class_ nnMatchPredicate(m, "MatchPredicateRef");

  // Each createNode overload adds a predicate node; unless strict is true the
  // predicate is marked nonTerminal().
  nnMatchGraph
      .def(
          "createEdge",
          [](nn::NNMatchGraph* g,
             nn::NNMatchGraph::NodeRef a,
             nn::NNMatchGraph::NodeRef b) { g->createEdge(a, b); })
      .def(
          "createNode",
          [](nn::NNMatchGraph* g, GenericOperator& op, bool strict) {
            // Matches graph nodes that pass the nn::is check and whose
            // operator name equals op's name.
            auto opName = op.getName();
            auto match = [opName](NNGraph::NodeRef node) {
              NOM_REQUIRE_OR_RET_FALSE(nn::is(node));
              auto nnOp = nn::get(node);
              return opName == nnOp->getName();
            };
            auto node = nn::NNMatchPredicate(match);
            if (!strict) {
              node.nonTerminal();
            }
            return g->createNode(std::move(node));
          },
          py::return_value_policy::reference_internal,
          py::arg("node"),
          py::arg("strict") = false)
      .def(
          "createNode",
          [](nn::NNMatchGraph* g, nom::repr::Tensor& tensor, bool strict) {
            // `tensor` only selects this overload; its name is not matched.
            auto node = nn::NNMatchPredicate(nn::is);
            if (!strict) {
              node.nonTerminal();
            }
            return g->createNode(std::move(node));
          },
          py::return_value_policy::reference_internal,
          py::arg("tensor"),
          py::arg("strict") = false)
      .def(
          "createNode",
          [](nn::NNMatchGraph* g, bool strict) {
            // Wildcard: matches any node.
            auto match = [](NNGraph::NodeRef node) { return true; };
            auto node = nn::NNMatchPredicate(match);
            if (!strict) {
              node.nonTerminal();
            }
            return g->createNode(std::move(node));
          },
          py::return_value_policy::reference_internal,
          py::arg("strict") = false)
      .def(
          "getMutableNodes",
          [](nn::NNMatchGraph* g) { return g->getMutableNodes(); },
          py::return_value_policy::reference_internal);

  // Tries to match pattern `mg` rooted at `node`; returns the matched
  // subgraph, or an empty NNSubgraph when there is no match.
  m.def("matchSubgraph", [](NNGraph::NodeRef node, nn::NNMatchGraph* mg) {
    // Get root node or node in root cycle
    auto sccs = nom::algorithm::tarjans(mg);
    // .back() on an empty SCC list is undefined behaviour.
    CAFFE_ENFORCE(
        !sccs.empty(), "matchSubgraph requires a non-empty NNMatchGraph");
    auto match_node = *sccs.back().getNodes().begin();
    auto result = mg->isSubgraphMatch(node, match_node, false);
    if (result.isMatch()) {
      return *result.getMatchedSubgraph();
    }
    return NNSubgraph();
  });

  // Annotation API
  py::class_ annotation(m, "Annotation");
  annotation.def(py::init<>())
      .def("setDevice", &Caffe2Annotation::setDevice)
      .def("getDevice", &Caffe2Annotation::getDevice)
      .def("setDeviceType", &Caffe2Annotation::setDeviceType)
      .def("getDeviceType", &Caffe2Annotation::getDeviceType)
      .def("setKeyNode", &Caffe2Annotation::setKeyNode)
      .def(
          "getKeyNode",
          &Caffe2Annotation::getKeyNode,
          py::return_value_policy::reference)
      .def("setLengthNode", &Caffe2Annotation::setLengthNode)
      .def(
          "getLengthNode",
          &Caffe2Annotation::getLengthNode,
          py::return_value_policy::reference)
      .def("setComponentLevels", &Caffe2Annotation::setComponentLevels)
      .def("getComponentLevels", &Caffe2Annotation::getComponentLevels)
      .def("hasDeviceOption", &Caffe2Annotation::hasDeviceOption)
      // device_option / operator_def are exchanged with Python by
      // serializing the C++ proto and re-parsing it into the matching
      // caffe2_pb2 message (and the reverse in the setters).
      .def_property(
          "device_option",
          [](Caffe2Annotation& annot) {
            // Named so it does not shadow the caffe2::DeviceOption type.
            auto pyDeviceOptionCls =
                py::module::import("caffe2.proto.caffe2_pb2")
                    .attr("DeviceOption");
            auto proto = annot.getDeviceOption();
            std::string serialized_proto;
            proto.SerializeToString(&serialized_proto);
            auto py_device_opt = pyDeviceOptionCls();
            py_device_opt.attr("ParseFromString")(py::bytes(serialized_proto));
            return py_device_opt;
          },
          [](Caffe2Annotation& annot, py::object& def) {
            CAFFE_ENFORCE(
                pybind11::hasattr(def, "SerializeToString"),
                "device_option can only be set to a DeviceOption");
            auto str = def.attr("SerializeToString")();
            caffe2::DeviceOption proto;
            CAFFE_ENFORCE(
                proto.ParseFromString(py::bytes(str)),
                "device_option could not be parsed as a DeviceOption");
            annot.setDeviceOption(proto);
          },
          py::return_value_policy::reference)
      .def_property(
          "operator_def",
          [](Caffe2Annotation& annot) {
            auto opDef = py::module::import("caffe2.proto.caffe2_pb2")
                             .attr("OperatorDef");
            auto proto = annot.getOperatorDef();
            std::string serialized_proto;
            proto.SerializeToString(&serialized_proto);
            auto py_op_def = opDef();
            py_op_def.attr("ParseFromString")(py::bytes(serialized_proto));
            return py_op_def;
          },
          [](Caffe2Annotation& annot, py::object& def) {
            CAFFE_ENFORCE(
                pybind11::hasattr(def, "SerializeToString"),
                "operator_def can only be set to an OperatorDef");
            auto str = def.attr("SerializeToString")();
            caffe2::OperatorDef proto;
            CAFFE_ENFORCE(
                proto.ParseFromString(py::bytes(str)),
                "operator_def could not be parsed as an OperatorDef");
            annot.setOperatorDef(proto);
          },
          py::return_value_policy::reference);
}

REGISTER_PYBIND_ADDITION(addNomnigraphMethods);

} // namespace python
} // namespace caffe2