// NOTE: the original include list was garbled in extraction (the header
// names inside <...> were stripped). The includes below are reconstructed
// from the identifiers used in this file and may differ from the original.
#include <torch/csrc/jit/init.h>

#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/autodiff.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/irparser.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/canonicalize_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/erase_number_types.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/quantization.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/python_arg_flatten.h>
#include <torch/csrc/jit/python_ir.h>
#include <torch/csrc/jit/python_tracer.h>
#include <torch/csrc/jit/script/init.h>
#include <torch/csrc/jit/script/jit_exception.h>
#include <torch/csrc/jit/script/python_tree_views.h>
#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/pybind.h>

#include <caffe2/serialize/inline_container.h>

#include <ATen/core/function_schema.h>
#include <c10/util/Exception.h>

namespace torch {
namespace jit {

using ::c10::Argument;
using ::c10::FunctionSchema;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;

namespace {

using autograd::variable_list;

bool loadPythonClasses() {
  // Leaving this code here, because it will likely be useful at some point
  // PyObject *jit_module = PyImport_ImportModule("torch.jit");
  // THPUtils_assert(jit_module, "class loader couldn't access "
  //                 "torch.jit module");
  // PyObject *jit_dict = PyModule_GetDict(jit_module);
  return true;
}

} // anonymous namespace

#if defined(_WIN32)
void runJITCPPTests() {
  AT_ERROR("JIT tests not yet supported on Windows");
}
#else
void runJITCPPTests();
#endif
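// initJITBindings attaches the JIT's Python-facing entry points to the
// torch._C extension module. A minimal usage sketch from the Python side
// (the IR string below is illustrative, not taken from this file):
//
//   import torch
//   g = torch._C.parse_ir("graph(%x : Tensor):\n  return (%x)")
//   torch._C._jit_pass_lint(g)
//   torch._C._jit_pass_dce(g)   # dead code elimination
//   torch._C._jit_pass_cse(g)   # common subexpression elimination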
void initJITBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::register_exception<JITException>(m, "JITException");

  py::class_<python::IODescriptor> iodescriptor(
      m, "IODescriptor"); // NOLINT(bugprone-unused-raii)

  m.def("_jit_init", loadPythonClasses)
      .def(
          "_jit_debug_fuser_num_cached_kernel_specs",
          torch::jit::fuser::debugNumCachedKernelSpecs)
      .def("_jit_pass_onnx", ToONNX)
      .def("_jit_pass_lower_all_tuples", LowerAllTuples)
      .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
      .def("_jit_pass_fuse", FuseGraph)
      .def(
          "_jit_pass_dce",
          [](std::shared_ptr<Graph>& g) {
            return EliminateDeadCode(g->block()); // overload resolution
          })
      .def(
          "_jit_pass_cse",
          [](std::shared_ptr<Graph>& g) {
            return EliminateCommonSubexpression(g); // overload resolution
          })
      .def(
          "_jit_pass_expand_fakequant",
          [](std::shared_ptr<Graph>& g) { return ExpandFakeQuantNodes(g); })
      .def(
          "_jit_pass_propagate_qinfo",
          [](std::shared_ptr<Graph>& g) { return PropagateQuantInfo(g); })
      .def(
          "_jit_pass_insert_observers",
          [](std::shared_ptr<Graph>& g, py::function pyObserverFunction) {
            // Create a new node to be used by the insert-observer pass:
            // all observer nodes will be cloned from this one.
            Node* new_node = g->createPythonOp(
                THPObjectPtr(pyObserverFunction.release().ptr()), "dd", {});
            InsertObserverNodes(g, new_node);
            // We don't need this node anymore; remove it.
            new_node->destroy();
          })
      .def(
          "_jit_pass_insert_fakequant",
          [](std::shared_ptr<Graph>& g) { return InsertFakeQuantNodes(g); })
      .def(
          "_jit_pass_quantlint",
          [](std::shared_ptr<Graph>& g) { return QuantLinting(g); })
      .def(
          "_jit_pass_fold_quant_inputs",
          [](std::shared_ptr<Graph>& g) {
            return FoldQuantNodesIntoInputsOutputs(g);
          })
      .def(
          "_jit_pass_remove_inplace_ops",
          [](std::shared_ptr<Graph> g) { return RemoveInplaceOps(g); })
      .def("_jit_pass_constant_pooling", ConstantPooling)
      .def(
          "_jit_pass_peephole",
          [](const std::shared_ptr<Graph>& g, bool addmm_fusion_enabled) {
            return PeepholeOptimize(g, addmm_fusion_enabled);
          },
          py::arg("graph"),
          py::arg("addmm_fusion_enabled") = false)
      .def(
          "_jit_pass_canonicalize",
          [](const std::shared_ptr<Graph>& g) { return Canonicalize(g); })
      .def("_jit_pass_lint", LintGraph)
      .def(
          "_jit_pass_shape_analysis",
          [](std::shared_ptr<Graph> graph,
             std::vector<at::Tensor> inputs,
             bool with_grad) {
            setInputTypes(
                *graph,
                ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
            PropagateInputShapes(graph);
          })
      .def(
          "_jit_pass_complete_shape_analysis",
          [](std::shared_ptr<Graph> graph, py::tuple inputs, bool with_grad) {
            CompleteArgumentSpec spec(
                with_grad,
                evilDeprecatedBadCreateStackDoNotUse(inputs, graph->inputs()));
            auto graph_inputs = graph->inputs();
            AT_ASSERT(spec.size() == graph_inputs.size());
            for (size_t i = 0; i < graph_inputs.size(); ++i) {
              graph_inputs[i]->setType(spec.at(i));
            }
            PropagateInputShapes(graph);
          })
      .def("_jit_pass_remove_expands", RemoveExpands)
      .def("_jit_pass_erase_number_types", EraseNumberTypes)
      .def("_jit_pass_inline_fork_wait", InlineForkWait)
      .def("_jit_pass_prepare_division_for_onnx", PrepareDivisionForONNX)
      .def("_jit_pass_loop_unrolling", UnrollLoops)
      .def(
          "_jit_pass_constant_propagation",
          [](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); })
      .def("_jit_pass_erase_shape_information", EraseShapeInformation)
      .def(
          "_jit_pass_create_autodiff_subgraphs",
          [](std::shared_ptr<Graph> graph) { CreateAutodiffSubgraphs(graph); })
      .def(
          "_jit_run_cpp_tests",
          [] {
            // We have to release the GIL inside this method, because if we
            // happen to initialize the autograd engine in these tests, the
            // newly spawned worker threads will try to initialize their
            // PyThreadState*, and they need the GIL for this.
            AutoNoGIL _no_gil;
            return runJITCPPTests();
          })
      .def(
          "_jit_flatten",
          [](py::handle& obj) {
            auto res = python::flatten(obj);
            return std::make_pair(res.vars, res.desc);
          })
      .def(
          "_jit_unflatten",
          [](autograd::variable_list vars, python::IODescriptor& desc) {
            return py::reinterpret_steal<py::object>(
                python::unflatten(vars, desc));
          })
      .def("_jit_pass_onnx_block", BlockToONNX)
      .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
      .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
      .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
      .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
      .def(
          "_jit_differentiate",
          [](Graph& g) {
            // The Python binding differs slightly in semantics from
            // jit::differentiate: it copies the input Graph and works on
            // the copy, whereas jit::differentiate mutates its input.
            auto g_clone = g.copy();
            return differentiate(g_clone);
          })
      .def(
          "_jit_check_alias_annotation",
          [](std::shared_ptr<Graph> g,
             py::tuple args,
             const std::string& unqualified_op_name) {
            auto stack = toStack(args);
            checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
          });
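  // The classes below expose executor internals (argument specs, compiled
  // code, gradients, and execution plans) to Python, mainly for tests and
  // debugging. A hedged sketch of driving GraphExecutor directly from
  // Python (the traced function, inputs, and the name-lookup callback are
  // illustrative assumptions, not taken from this file):
  //
  //   ge = torch._C.GraphExecutor(lambda x: x * 2, (torch.ones(2),),
  //                               lambda value: "", True, False)
  //   out = ge(torch.ones(2))
  //   opt_graph = ge.graph_for(torch.ones(2))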
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
      .def("__repr__", [](CompleteArgumentSpec& self) {
        std::ostringstream s;
        s << self;
        return s.str();
      });
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<ArgumentSpec>(m, "ArgumentSpec");
  py::class_<Code>(m, "Code").def("grad_executors", [](Code& c) {
    return py::make_iterator(
        c.grad_executors().begin(), c.grad_executors().end());
  });

  py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
      .def_property_readonly(
          "graph", [](ExecutionPlanState& s) { return s.graph; })
      .def_property_readonly(
          "code", [](ExecutionPlanState& s) { return s.code; });

  py::class_<Gradient>(m, "Gradient")
      .def_property_readonly("f", [](Gradient& m) { return m.f; })
      .def_property_readonly("df", [](Gradient& m) { return m.df; })
      .def_property_readonly(
          "f_real_outputs", [](Gradient& m) { return m.f_real_outputs; })
      .def_property_readonly(
          "df_input_vjps", [](Gradient& m) { return m.df_input_vjps; })
      .def_property_readonly(
          "df_input_captured_inputs",
          [](Gradient& m) { return m.df_input_captured_inputs; })
      .def_property_readonly(
          "df_input_captured_outputs",
          [](Gradient& m) { return m.df_input_captured_outputs; })
      .def_property_readonly(
          "df_output_vjps", [](Gradient& m) { return m.df_output_vjps; });

  py::class_<GraphExecutorState>(m, "GraphExecutorState")
      .def_property_readonly(
          "graph", [](GraphExecutorState& s) { return s.graph; })
      .def_property_readonly(
          "execution_plans",
          [](GraphExecutorState& s) { return s.execution_plans; })
      .def_property_readonly(
          "fallback", [](GraphExecutorState& s) { return s.fallback; });

  py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
      .def(
          py::init([](py::function func,
                      py::tuple inputs,
                      py::function var_name_lookup_fn,
                      bool optimize,
                      bool _force_outplace) {
            auto graph = tracer::createGraphByTracing(
                func, toStack(inputs), var_name_lookup_fn, _force_outplace);
            return GraphExecutor(graph, optimize);
          }),
          py::arg("func"),
          py::arg("inputs"),
          py::arg("var_name_lookup_fn"),
          py::arg("optimize") = true,
          py::arg("_force_outplace") = false)
      .def(
          py::init([](std::shared_ptr<Graph> graph, bool optimize) {
            return GraphExecutor(std::move(graph), optimize);
          }),
          py::arg("graph"),
          py::arg("optimize") = true)
      .def(
          "graph_for",
          [](GraphExecutor& ge, py::args args) {
            return ge.graphFor(evilDeprecatedBadCreateStackDoNotUse(
                args, ge.graph()->inputs()));
          })
      .def_property_readonly(
          "graph", [](GraphExecutor& ge) { return ge.graph(); })
      .def(
          "get_debug_state",
          [](GraphExecutor& ge) { return ge.getDebugState(); })
      .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
        const auto& graph = ge.graph();
        auto stack =
            evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
        {
          AutoNoGIL no_gil_guard;
          ge.run(stack);
        }
        return createPyObjectForStack(std::move(stack));
      });

  py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
      .def(py::init<std::string>())
      .def(
          "write_record",
          [](PyTorchStreamWriter& self,
             const std::string& name,
             const char* data,
             size_t size) { return self.writeRecord(name, data, size); })
      .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile);

  py::class_<PyTorchStreamReader>(m, "PyTorchFileReader")
      .def(py::init<std::string>())
      .def(
          "get_record",
          [](PyTorchStreamReader& self, const std::string& key) {
            at::DataPtr data;
            size_t size;
            std::tie(data, size) = self.getRecord(key);
            return py::bytes(
                reinterpret_cast<const char*>(data.get()), size);
          });
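  // Sketch of the container bindings above, assuming torch._C as the module
  // path: PyTorchFileWriter packs named byte records into a single archive,
  // and PyTorchFileReader fetches them back by name (file name and payload
  // below are illustrative):
  //
  //   w = torch._C.PyTorchFileWriter("archive.pt")
  //   w.write_record("my_key", b"payload", 7)
  //   w.write_end_of_file()
  //   r = torch._C.PyTorchFileReader("archive.pt")
  //   data = r.get_record("my_key")   # b"payload"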
  m.def(
      "_jit_get_operation",
      [](const std::string& qualified_name) {
        try {
          auto symbol = Symbol::fromQualString(qualified_name);
          auto operations = getAllOperatorsFor(symbol);
          AT_CHECK(!operations.empty(), "No such operator ", qualified_name);
          AT_CHECK(
              operations.size() == 1,
              "Found ",
              operations.size(),
              " overloads for operator ",
              qualified_name,
              "! Overloads are not supported from Python.");
          std::shared_ptr<Operator> op = operations[0];
          AT_ASSERT(op != nullptr);
          std::ostringstream docstring;
          docstring << "Automatically bound operator '" << qualified_name
                    << "' with schema: " << op->schema();
          return py::cpp_function(
              [op](py::args args, py::kwargs kwargs) {
                return invokeOperatorFromPython(
                    *op, std::move(args), std::move(kwargs));
              },
              py::name(qualified_name.c_str()),
              py::doc(docstring.str().c_str()));
        } catch (const c10::Error& error) {
          throw std::runtime_error(error.what_without_backtrace());
        }
      },
      py::arg("qualified_name"));

  m.def("parse_ir", [](const std::string& input) {
    auto graph = std::make_shared<Graph>();
    script::parseIR(input, &*graph);
    return graph;
  });

  py::class_<FunctionSchema>(m, "FunctionSchema")
      .def_property_readonly(
          "name", [](FunctionSchema& self) { return self.name(); })
      .def_property_readonly(
          "overload_name",
          [](FunctionSchema& self) { return self.overload_name(); })
      .def_property_readonly(
          "arguments", [](FunctionSchema& self) { return self.arguments(); })
      .def_property_readonly(
          "returns", [](FunctionSchema& self) { return self.returns(); });
  py::class_<Argument>(m, "Argument")
      .def_property_readonly(
          "name", [](Argument& self) { return self.name(); })
      .def_property_readonly(
          "type", [](Argument& self) { return self.type(); })
      .def_property_readonly(
          "N",
          [](Argument& self) -> py::object {
            return (self.N()) ? py::cast(*self.N()) : py::none();
          })
      .def_property_readonly(
          "default_value",
          [](Argument& self) -> py::object {
            if (!self.default_value())
              return py::none();
            IValue v = *self.default_value();
            return toPyObject(std::move(v));
          });

  m.def(
      "_jit_get_schemas_for_operator",
      [](const std::string& qualified_name) {
        auto symbol = Symbol::fromQualString(qualified_name);
        auto operations = getAllOperatorsFor(symbol);
        return fmap(operations, [](const std::shared_ptr<Operator>& op) {
          return op->schema();
        });
      });
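  // Hedged sketch of operator lookup from Python (the operator names are
  // illustrative; _jit_get_operation raises for overloaded names):
  //
  //   relu = torch._C._jit_get_operation("aten::relu")
  //   y = relu(torch.randn(3))
  //   for schema in torch._C._jit_get_schemas_for_operator("aten::add"):
  //       print(schema.name, [arg.name for arg in schema.arguments])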
  struct PythonFutureWrapper {
    explicit PythonFutureWrapper(c10::intrusive_ptr<c10::ivalue::Future> fut)
        : fut(std::move(fut)) {}

    c10::intrusive_ptr<c10::ivalue::Future> fut;
  };

  py::class_<PythonFutureWrapper>(m, "Future");

  m.def("fork", [](py::args args) {
    AT_ASSERT(args.size() >= 1);

    py::function f = py::cast<py::function>(args[0]);
    py::tuple args_tup(args.size() - 1);
    for (size_t i = 1; i < args.size(); ++i) {
      args_tup[i - 1] = args[i];
    }

    if (jit::tracer::isTracing()) {
      auto graph = jit::tracer::getTracingState()->graph;
      auto fork_node = graph->insertNode(graph->create(prim::fork, 1));
      auto body_block = fork_node->addBlock();

      Value* node_output;
      py::object py_func_output;
      auto retval = c10::make_intrusive<c10::ivalue::Future>();

      // Insert new trace ops into the fork op's sub-block
      WithInsertPoint guard(body_block);
      IValue output_ivalue;
      {
        tracer::WithNestedTracingFrame env_guard;

        // Run the user-supplied function
        py_func_output = f(*args_tup);

        // Convert the output of the user-supplied function to an IValue.
        // The type information of this IValue is used to record the correct
        // type in the trace.
        output_ivalue = toIValue(py_func_output);
        Value* out_val = jit::tracer::getNestedValueTrace(output_ivalue);
        body_block->registerOutput(out_val);
        node_output =
            fork_node->output()->setType(FutureType::create(out_val->type()));

        // Lambda lift into a Subgraph attribute
        torch::jit::script::lambdaLiftFork(fork_node);
      }

      // Record the IValue in the tracer
      jit::tracer::setValueTrace(retval, node_output);

      // Stuff the IValue output into the Future
      retval->markCompleted(output_ivalue);
      return PythonFutureWrapper(retval);
    } else {
      auto retval = c10::make_intrusive<c10::ivalue::Future>();
      retval->markCompleted(toIValue(f(*args_tup)));
      return PythonFutureWrapper(retval);
    }
  });

  m.def("wait", [](PythonFutureWrapper& fut) {
    if (jit::tracer::isTracing()) {
      auto graph = jit::tracer::getTracingState()->graph;
      Value* fut_val = jit::tracer::getValueTrace(fut.fut);
      auto output = graph->insert(aten::wait, {fut_val});
      jit::tracer::setValueTrace(fut.fut->value(), output);
    }
    return fut.fut->value();
  });

  m.def("_jit_assert_is_instance", [](py::object obj, TypePtr type) {
    toIValue(obj, type);
  });

  initPythonIRBindings(module);
  tracer::initPythonTracerBindings(module);
  script::initTreeViewBindings(module);
  script::initJitScriptBindings(module);
}

} // namespace jit
} // namespace torch