#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace py = pybind11; namespace torch { namespace jit { namespace { // Note: const_cast is used twice below to acquire a handle to a pyobject. Operation createPythonOperation(const Node* op_) { AutoGIL gil; const PythonOp* op = static_cast(op_); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) const py::function func = py::reinterpret_borrow(py::handle(const_cast(op)->pyobj.get())); size_t num_inputs = 0; for(auto arg_type : op->cconv) { if(arg_type == 'd') num_inputs++; } JIT_ASSERT(op->outputs().size() == 1); return [=](Stack & stack) { AutoGIL gil; py::tuple py_inputs(op->cconv.size()); size_t i = 0; size_t next_scalar = 0; size_t next_tensor = 0; for (auto arg_type : op->cconv) { if (arg_type == 'c') { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) py_inputs[i] = py::reinterpret_borrow(const_cast(op)->scalar_args[next_scalar++].get()); } else if (arg_type == 'd') { py_inputs[i] = toPyObject(std::move(peek(stack, next_tensor, num_inputs))); next_tensor++; } i++; } drop(stack, num_inputs); try { py::object py_output(func(*py_inputs)); stack.push_back(returnToIValue(op->output()->type(), py_output)); } catch (py::error_already_set & e) { throw std::runtime_error(e.what()); } return 0; }; } RegisterOperators reg({ Operator(prim::PythonOp, createPythonOperation) }); }}} // torch::jit::anon