pytorch/torch/csrc/jit/python/init.cpp
eellison d5df055bbb [WIP][JIT] Add JIT backend registration API (#35833)
Summary:
**Summary**
This commit adds `torch::jit::RegisterBackend`, an API that allows
external backends to be registered for the execution of JIT subgraphs
outside the JIT interpreter. In order to register an external backend,
one must extend the provided abstract class `PyTorchBackendInterface` and provide
two additional functions: one that creates an instance of the aforementioned subclass
of `PyTorchBackendInterface`, and another that preprocesses a `ScriptModule` so that
it can run on the backend. Then, a `ScriptModule` that can compile and execute a given
JIT subgraph using the functions provided at registration time is generated
for each registered backend.

**Testing**
This commit adds a unit test that uses a minimal test backend
to make sure that the registration endpoint and generated
`ScriptModule` work.

```
$ python test/test_jit.py TestBackends
Fail to import hypothesis in common_utils, tests are not derandomized
.
----------------------------------------------------------------------
Ran 1 test in 0.183s

OK

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35833

Differential Revision: D21231955

Pulled By: SplitInfinity

fbshipit-source-id: 452db1123d0e5d83f97fe5da8a00fdfdb50dbef9
2020-05-07 18:15:26 -07:00

923 lines
35 KiB
C++

#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_init.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/canonicalize_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/create_functional_graphs.h>
#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/erase_number_types.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_graph.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/prepare_inplace_ops_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
#include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/quantization.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_arg_flatten.h>
#include <torch/csrc/jit/python/python_custom_class.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/python/python_tree_views.h>
#include <torch/csrc/jit/python/script_init.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <c10/macros/Export.h>
#include <caffe2/serialize/inline_container.h>
#include <ATen/core/function_schema.h>
#include <pybind11/functional.h>
#include <pybind11/iostream.h>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
namespace torch {
namespace jit {
using ::c10::Argument;
using ::c10::FunctionSchema;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;
namespace {
using autograd::variable_list;
// Entry point bound to Python as `_jit_init`. Nothing has to be loaded at the
// moment, so the function only reports success.
//
// The sketch below is kept because it will likely be useful at some point:
//   PyObject* jit_module = PyImport_ImportModule("torch.jit");
//   THPUtils_assert(
//       jit_module, "class loader couldn't access torch.jit module");
//   PyObject* jit_dict = PyModule_GetDict(jit_module);
bool loadPythonClasses() {
  return true;
}
} // anonymous namespace
#if !defined(__HIP_PLATFORM_HCC__)
TORCH_API void runJITCPPTests(bool runCuda);
TORCH_API void runTENSOREXPRCPPTests(bool runCuda);
#endif
void initJITBindings(PyObject* module) {
// Reinterpret the raw PyObject* as a pybind11 module; every binding below is
// attached to `m`.
auto m = py::handle(module).cast<py::module>();
// Make the C++ JITException visible to Python under the name "JITException".
py::register_exception<JITException>(m, "JITException");
// IODescriptor is registered without any methods; `_jit_unflatten` below
// accepts it as an argument.
py::class_<python::IODescriptor> iodescriptor(
m, "IODescriptor"); // NOLINT(bugprone-unused-raii)
// Free-function bindings attached directly to the module: graph/module passes,
// ONNX export helpers, quantization passes, fuser and profiling switches and
// tensor-expression settings. Many lambdas exist only to select one overload
// of the underlying pass or to adapt Python argument types.
m.def("_jit_init", loadPythonClasses)
    .def(
        "_jit_debug_fuser_num_cached_kernel_specs",
        torch::jit::fuser::debugNumCachedKernelSpecs)
    .def("_jit_pass_onnx_remove_print", RemovePrintOps)
    .def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
    .def("_jit_pass_onnx", ToONNX)
    .def("_jit_pass_lower_all_tuples", LowerAllTuples)
    .def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution)
    .def(
        "_jit_pass_onnx_peephole",
        [](std::shared_ptr<Graph>& graph,
           int opset_version,
           bool fixed_batch_size) {
          return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size);
        })
    .def(
        "_jit_pass_onnx_cast_all_constant_to_floating",
        CastAllConstantToFloating)
    .def(
        "_jit_pass_onnx_constant_fold",
        [](std::shared_ptr<Graph>& graph,
           std::map<std::string, IValue>& paramsDict,
           int opset_version) {
          ConstantFoldONNX(
              graph->block(),
              paramsDict,
              opset_version); // overload resolution
          // The (possibly updated) parameter map is handed back to Python.
          return paramsDict;
        },
        pybind11::return_value_policy::move)
    .def("_jit_pass_onnx_scalar_type_analysis", ScalarTypeAnalysisForONNX)
    .def(
        "_jit_pass_onnx_prepare_inplace_ops_for_onnx",
        PrepareInplaceOpsForONNX)
    .def("_jit_pass_fuse", FuseGraph)
    .def(
        "_jit_pass_dce",
        [](std::shared_ptr<Graph>& g) {
          return EliminateDeadCode(g->block()); // overload resolution
        })
    .def(
        "_jit_pass_dce_allow_deleting_nodes_with_side_effects",
        [](std::shared_ptr<Graph>& g) {
          return EliminateDeadCode(
              g->block(),
              true,
              DCESideEffectPolicy::
                  ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); // overload
                                                           // resolution
        })
    .def(
        "_jit_pass_cse",
        [](std::shared_ptr<Graph>& g) {
          return EliminateCommonSubexpression(g); // overload resolution
        })
    .def(
        "_jit_pass_insert_observers",
        [](Module& module,
           const std::string& method_name,
           const py::dict& qconfig_dict,
           bool inplace,
           bool is_dynamic) {
          // Convert the Python dict into the C++ map type that is passed to
          // InsertObservers.
          auto dict = py::cast<std::unordered_map<
              std::string,
              c10::optional<std::tuple<Module, Module>>>>(qconfig_dict);
          return InsertObservers(
              module, method_name, dict, inplace, is_dynamic);
        },
        py::arg("module"),
        py::arg("method_name"),
        py::arg("qconfig_dict"),
        py::arg("inplace") = false,
        py::arg("is_dynamic") = false)
    .def(
        "_jit_pass_insert_quant_dequant",
        [](Module& module,
           const std::string& method_name,
           bool inplace,
           bool is_dynamic) {
          return InsertQuantDeQuant(module, method_name, inplace, is_dynamic);
        },
        py::arg("module"),
        py::arg("method_name"),
        py::arg("inplace") = false,
        py::arg("is_dynamic") = false)
    // Registered twice under one name: once for a Graph, once for a Module.
    .def(
        "_jit_pass_insert_prepack_unpack",
        [](std::shared_ptr<Graph>& g) { return InsertPrepackUnpack(g); })
    .def(
        "_jit_pass_insert_prepack_unpack",
        [](Module& module) { return InsertPrepackUnpack(module); })
    .def(
        "_jit_pass_quant_fusion",
        [](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
    .def("_jit_pass_fold_convbn", &FoldConvBatchNorm2d)
    .def(
        "_freeze_module",
        [](Module& module) { return freeze_module(module); },
        py::arg("module"))
    .def("_jit_pass_fuse_linear", &FuseLinear)
    .def(
        "_jit_pass_fold_quantize",
        [](Module& module, const std::string& method_name) {
          FoldQuantizeCallIntoBuffer(module, method_name);
        })
    .def("_jit_pass_fold_prepack", &FoldPrepackedWeightIntoModule)
    .def("_jit_pass_dedup_module_uses", &DedupModuleUses)
    .def("_jit_pass_replicate_dequantize", &ReplicateDeQuant)
    .def("_jit_pass_swap_dequantize", &PropagateQuantizationOps)
    // Registered twice under one name: once for a Graph, once for a Module.
    .def(
        "_jit_pass_swap_functional_linear",
        [](std::shared_ptr<Graph>& graph) { SwapFunctionalLinear(graph); })
    .def(
        "_jit_pass_swap_functional_linear",
        [](Module& module) { SwapFunctionalLinear(module); })
    .def(
        "_jit_pass_quant_finalize",
        [](Module& module, bool is_dynamic) {
          return Finalize(module, is_dynamic);
        },
        py::arg("module"),
        py::arg("is_dynamic") = false)
    .def(
        "_jit_pass_pattern_based_rewrite",
        [](const Module& m) { return PatternBasedRewrite(m); })
    .def(
        "_jit_pass_custom_pattern_based_rewrite",
        [](const std::string& pattern,
           const std::string& fused_node_name,
           const Module& m) {
          // A fresh rewriter holding exactly one pattern is used per call.
          SubgraphRewriter subgraph_rewriter;
          subgraph_rewriter.RegisterRewritePattern(pattern, fused_node_name);
          subgraph_rewriter.runOnModule(m);
        })
    .def(
        "_jit_pass_custom_pattern_based_rewrite_graph",
        [](const std::string& pattern,
           const std::string& fused_node_name,
           std::shared_ptr<Graph> g) {
          SubgraphRewriter subgraph_rewriter;
          subgraph_rewriter.RegisterRewritePattern(pattern, fused_node_name);
          subgraph_rewriter.runOnGraph(g);
        })
    .def(
        "_jit_pass_fold_quant_inputs",
        [](std::shared_ptr<Graph>& g) {
          return FoldQuantNodesIntoInputsOutputs(g);
        })
    .def(
        "_jit_pass_remove_inplace_ops",
        [](std::shared_ptr<Graph> g) { return RemoveInplaceOps(g); })
    .def("_jit_pass_constant_pooling", ConstantPooling)
    .def(
        "_jit_pass_create_functional_graphs",
        [](std::shared_ptr<Graph>& g) { return CreateFunctionalGraphs(g); })
    .def(
        "_jit_pass_remove_mutation",
        [](std::shared_ptr<Graph>& g) {
          // List mutation is removed first, then tensor mutation; only the
          // result of the second pass is returned.
          RemoveListMutation(g);
          return RemoveTensorMutation(g);
        })
    .def(
        "_jit_pass_inline_functional_graphs",
        [](std::shared_ptr<Graph>& g) { return InlineFunctionalGraphs(g); })
    .def(
        "_jit_pass_peephole",
        [](const std::shared_ptr<Graph>& g, bool addmm_fusion_enabled) {
          return PeepholeOptimize(g, addmm_fusion_enabled);
        },
        py::arg("graph"),
        py::arg("addmm_fusion_enabled") = false)
    .def(
        "_jit_pass_fuse_addmm",
        [](std::shared_ptr<Graph>& g) { return FuseAddMM(g); })
    .def(
        "_jit_pass_canonicalize",
        [](const std::shared_ptr<Graph>& g) { return Canonicalize(g); })
    .def("_jit_pass_lint", LintGraph)
    .def(
        "_jit_pass_complete_shape_analysis",
        [](std::shared_ptr<Graph> graph, py::tuple inputs, bool with_grad) {
          ArgumentSpecCreator arg_spec_creator(*graph);
          // Convert every Python input into an IValue on the stack.
          Stack stack;
          stack.reserve(inputs.size());
          for (auto& obj : inputs) {
            stack.push_back(toTypeInferredIValue(obj));
          }
          ArgumentSpec spec = arg_spec_creator.create(with_grad, stack);
          arg_spec_creator.specializeTypes(*graph, spec);
          // We only get partial specialization from the arg_spec_creator, but
          // we want full shape specialization. The alternative would be to
          // have a "complete type inference" function in ArgumentSpecCreator.
          auto g_inputs = graph->inputs();
          for (size_t i = 0; i < inputs.size(); ++i) {
            if (stack[i].isTensor()) {
              g_inputs[i]->setType(stack[i].type());
            }
          }
          PropagateInputShapes(graph);
        })
    .def("_jit_pass_remove_expands", RemoveExpands)
    .def("_jit_pass_erase_number_types", EraseNumberTypes)
    .def("_jit_pass_inline_fork_wait", InlineForkWait)
    .def("_jit_pass_inline", Inline)
    .def("_jit_pass_prepare_division_for_onnx", PrepareDivisionForONNX)
    .def(
        "_jit_pass_lower_graph",
        [](std::shared_ptr<Graph>& graph, const Module& self) {
          return LowerGraph(*graph, self._ivalue());
        })
    .def("_jit_pass_loop_unrolling", UnrollLoops)
    .def(
        "_jit_pass_constant_propagation",
        [](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); })
    .def("_jit_pass_erase_shape_information", EraseShapeInformation)
    .def(
        "_jit_pass_create_autodiff_subgraphs",
        [](std::shared_ptr<Graph> graph) { CreateAutodiffSubgraphs(graph); })
#if defined(BUILDING_TESTS) && !defined(__HIP_PLATFORM_HCC__)
    .def(
        "_jit_run_cpp_tests",
        [](bool runCuda) {
          // We have to release the GIL inside this method, because if we
          // happen to initialize the autograd engine in these tests, the
          // newly spawned worker threads will try to initialize their
          // PyThreadState*, and they need the GIL for this.
          pybind11::gil_scoped_release _no_gil;
          return runJITCPPTests(runCuda);
        },
        py::arg("run_cuda"))
    .def("_jit_has_cpp_tests", []() { return true; })
    .def(
        "_run_tensorexpr_cpp_tests",
        [](bool runCuda) {
          // We have to release the GIL inside this method, because if we
          // happen to initialize the autograd engine in these tests, the
          // newly spawned worker threads will try to initialize their
          // PyThreadState*, and they need the GIL for this.
          pybind11::gil_scoped_release _no_gil;
          return runTENSOREXPRCPPTests(runCuda);
        },
        py::arg("run_cuda"))
    .def("_has_tensorexpr_cpp_tests", []() { return true; })
#else
    // Stubs used when the C++ tests are not compiled in: the `_has_*` probes
    // report false and the runners fail with an explanatory message instead
    // of a bare std::exception that carries no diagnostic.
    .def(
        "_jit_run_cpp_tests",
        []() {
          throw std::runtime_error(
              "_jit_run_cpp_tests is unavailable: the JIT C++ tests were not "
              "compiled into this build");
        })
    .def("_jit_has_cpp_tests", []() { return false; })
    .def(
        "_run_tensorexpr_cpp_tests",
        []() {
          throw std::runtime_error(
              "_run_tensorexpr_cpp_tests is unavailable: the tensorexpr C++ "
              "tests were not compiled into this build");
        })
    .def("_has_tensorexpr_cpp_tests", []() { return false; })
#endif
    .def(
        "_jit_flatten",
        [](py::handle& obj) {
          // Returns the pair (flattened variables, structure descriptor).
          auto res = python::flatten(obj);
          return std::make_pair(res.vars, res.desc);
        })
    .def(
        "_jit_unflatten",
        [](autograd::variable_list vars, python::IODescriptor& desc) {
          // reinterpret_steal takes over the reference returned by unflatten.
          return py::reinterpret_steal<py::object>(
              python::unflatten(vars, desc));
        })
    .def("_jit_pass_onnx_block", BlockToONNX)
    .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
    .def("_jit_pass_fixup_onnx_conditionals", FixupONNXConditionals)
    .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
    .def("_jit_pass_decompose_ops", DecomposeOps)
    .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
    .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
    .def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
    .def("_jit_can_fuse_on_cpu", &canFuseOnCPU)
    .def("_jit_can_fuse_on_gpu", &canFuseOnGPU)
    .def(
        "_jit_differentiate",
        [](Graph& g) {
          // the python binding slightly differs in semantics
          // it makes a copy of the input Graph, and works on that
          // jit::differentiate mutates the input Graph
          auto g_clone = g.copy();
          return differentiate(g_clone);
        })
    .def(
        "_jit_check_alias_annotation",
        [](std::shared_ptr<Graph> g,
           py::tuple args,
           const std::string& unqualified_op_name) {
          auto stack = toTraceableStack(args);
          checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
        })
    .def("_jit_register_cuda_fuser", &RegisterCudaFuseGraph::registerPass)
    .def("_jit_clear_cuda_fuser", &RegisterCudaFuseGraph::clearPass)
    // The four setters below install a new value and return the previous one.
    .def(
        "_jit_set_profiling_mode",
        [](bool profiling_flag) {
          bool oldState = getProfilingMode();
          getProfilingMode() = profiling_flag;
          return oldState;
        })
    .def(
        "_jit_set_profiling_executor",
        [](bool profiling_flag) {
          bool oldState = getExecutorMode();
          getExecutorMode() = profiling_flag;
          return oldState;
        })
    .def(
        "_jit_set_num_profiled_runs",
        [](size_t num) {
          size_t old_num = getNumProfiledRuns();
          getNumProfiledRuns() = num;
          return old_num;
        })
    .def(
        "_jit_set_bailout_depth",
        [](size_t depth) {
          size_t old_depth = getBailoutDepth();
          getBailoutDepth() = depth;
          return old_depth;
        })
    .def(
        "_jit_set_inline_everything_mode",
        [](bool enabled) { getInlineEverythingMode() = enabled; })
    .def(
        "_jit_get_inline_everything_mode",
        []() { return getInlineEverythingMode(); })
    .def(
        "_jit_try_infer_type",
        [](py::object obj) -> TypePtr {
          // A null TypePtr is returned when no type could be inferred.
          auto match = tryToInferType(obj);
          if (match.success()) {
            return match.type();
          }
          return nullptr;
        })
    .def(
        "_jit_get_trigger_value",
        [](const std::string& trigger_name) {
          using namespace torch::jit::tensorexpr;
          ExecutionTrigger* trigger =
              ExecutionTriggerList::GetInstance().FindByName(trigger_name);
          return trigger->value();
        })
    // Tensor-expression CUDA tuning knobs. Each setter assigns through the
    // accessor and returns the result of that assignment.
    .def(
        "_jit_get_te_cuda_pointwise_loop_levels",
        []() -> int {
          using namespace torch::jit::tensorexpr;
          return getTECudaPointwiseLoopLevels();
        })
    .def(
        "_jit_set_te_cuda_pointwise_loop_levels",
        [](int level) {
          using namespace torch::jit::tensorexpr;
          return getTECudaPointwiseLoopLevels() = level;
        })
    .def(
        "_jit_get_te_cuda_pointwise_block_count",
        []() -> int {
          using namespace torch::jit::tensorexpr;
          return getTECudaPointwiseBlockCount();
        })
    .def(
        "_jit_set_te_cuda_pointwise_block_count",
        [](int block_count) {
          using namespace torch::jit::tensorexpr;
          return getTECudaPointwiseBlockCount() = block_count;
        })
    .def(
        "_jit_get_te_cuda_pointwise_block_size",
        []() -> int {
          using namespace torch::jit::tensorexpr;
          return getTECudaPointwiseBlockSize();
        })
    .def(
        "_jit_set_te_cuda_pointwise_block_size",
        [](int block_size) {
          using namespace torch::jit::tensorexpr;
          return getTECudaPointwiseBlockSize() = block_size;
        })
    .def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled)
    .def("_jit_texpr_fuser_enabled", &tensorExprFuserEnabled)
    .def("_jit_texpr_fallback_allowed", &tensorexpr::fallbackAllowed)
    .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
    .def(
        "_jit_pass_fuse_tensorexprs",
        [](std::shared_ptr<Graph>& g) { return FuseTensorExprs(g); })
    .def(
        "_jit_fuser_get_fused_kernel_code",
        [](Graph& g, std::vector<at::Tensor> inps) {
          return debugGetFusedKernelCode(g, inps);
        })
    // Registered twice under one name: once for a Graph, once for a Module.
    .def(
        "_jit_pass_insert_prepacked_ops",
        [](std::shared_ptr<Graph>& graph) {
          return insertPrePackedOps(graph);
        })
    .def(
        "_jit_pass_insert_prepacked_ops",
        [](script::Module& module) { return insertPrePackedOps(module); })
    .def(
        "_jit_pass_fuse_clamp_w_prepacked_linear_conv",
        [](script::Module& module) {
          return fusePrePackedLinearConvWithClamp(module);
        })
    .def(
        "_jit_pass_fold_prepacking_ops",
        [](script::Module& module) { return FoldPrePackingOps(module); })
    .def(
        "_jit_pass_optimize_for_mobile",
        [](script::Module& module) { return optimizeForMobile(module); })
    .def(
        "_jit_pass_onnx_unpack_quantized_weights",
        [](std::shared_ptr<Graph>& graph,
           std::map<std::string, IValue>& paramsDict) {
          UnpackQuantizedWeights(graph, paramsDict);
          return paramsDict;
        },
        pybind11::return_value_policy::move)
    .def(
        "_jit_pass_onnx_quantization_insert_permutes",
        [](std::shared_ptr<Graph>& graph,
           std::map<std::string, IValue>& paramsDict) {
          insertPermutes(graph, paramsDict);
          return paramsDict;
        },
        pybind11::return_value_policy::move)
    .def(
        "_jit_pass_filter_non_tensor_arguments",
        [](std::map<std::string, IValue> params) {
          // Keep only the entries whose value is a tensor.
          std::map<std::string, at::Tensor> retval;
          for (auto& kv : params) {
            if (kv.second.isTensor()) {
              retval[kv.first] = std::move(kv.second).toTensor();
            }
          }
          return retval;
        })
    .def("_jit_decay_packed_param_input_types", [](Graph& g) {
      // Replace the type of every graph input that is a quantized conv
      // packed-params custom class with a placeholder tensor type.
      for (Value* i : g.inputs()) {
        if (i->type() ==
                getCustomClass(
                    "__torch__.torch.classes.quantized.Conv2dPackedParamsBase") ||
            i->type() ==
                getCustomClass(
                    "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) {
          // Dummy CompleteTensorType to appease ONNX validator.
          i->setType(TensorType::create(
              at::kQInt8,
              c10::kCPU,
              std::vector<int64_t>{1},
              std::vector<int64_t>{1},
              c10::nullopt));
        }
      }
    });
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<CompleteArgumentSpec>(m, "CompleteArgumentSpec")
    .def("__repr__", [](CompleteArgumentSpec& spec) {
      // repr() is whatever the spec's stream-insertion operator prints.
      std::ostringstream rendered;
      rendered << spec;
      return rendered.str();
    });
// ArgumentSpec is registered as a type with no methods or properties.
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<ArgumentSpec>(m, "ArgumentSpec");
// Debugging hooks on an interpreter Code object.
py::class_<Code>(m, "Code")
    .def(
        "grad_executor_states",
        [](Code& code) {
          // Collect the debug state of every gradient executor of `code`.
          std::vector<GraphExecutorState> snapshots;
          for (auto& executor : code.grad_executors()) {
            snapshots.emplace_back(executor->getDebugState());
          }
          return snapshots;
        })
    .def("num_bailouts", [](Code& code) { return code.num_bailouts(); })
    .def("request_bailout", [](Code& code, size_t index) {
      code.request_bailout(index);
    });
// Read-only view of an ExecutionPlan: its graph and its compiled code.
py::class_<ExecutionPlan>(m, "ExecutionPlan")
    .def_property_readonly(
        "graph", [](ExecutionPlan& plan) { return plan.graph; })
    .def_property_readonly(
        "code", [](ExecutionPlan& plan) { return plan.code; });
// Read-only view of the fields of a Gradient (the result of differentiate()).
py::class_<Gradient>(m, "Gradient")
    .def_property_readonly("f", [](Gradient& grad) { return grad.f; })
    .def_property_readonly("df", [](Gradient& grad) { return grad.df; })
    .def_property_readonly(
        "f_real_outputs",
        [](Gradient& grad) { return grad.f_real_outputs; })
    .def_property_readonly(
        "df_input_vjps",
        [](Gradient& grad) { return grad.df_input_vjps; })
    .def_property_readonly(
        "df_input_captured_inputs",
        [](Gradient& grad) { return grad.df_input_captured_inputs; })
    .def_property_readonly(
        "df_input_captured_outputs",
        [](Gradient& grad) { return grad.df_input_captured_outputs; })
    .def_property_readonly(
        "df_output_vjps",
        [](Gradient& grad) { return grad.df_output_vjps; });
// Read-only view of a GraphExecutorState: graph, execution plans, fallback.
py::class_<GraphExecutorState>(m, "GraphExecutorState")
    .def_property_readonly(
        "graph", [](GraphExecutorState& state) { return state.graph; })
    .def_property_readonly(
        "execution_plans",
        [](GraphExecutorState& state) { return state.execution_plans; })
    .def_property_readonly(
        "fallback",
        [](GraphExecutorState& state) { return state.fallback; });
// Python-visible writer. It can be constructed from a file name, from an
// object exposing `write`, or from a raw write callback (registered in that
// order). `write_record` accepts the payload either as a character buffer or
// as an integer address.
py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
    .def(py::init<std::string>())
    .def(py::init([](const py::object& buffer) {
      // Hand each chunk to buffer.write() as a Python `bytes` object and
      // report the full chunk size as written.
      auto write_chunk = [buffer](const void* chunk, size_t nbytes) {
        py::bytes payload(reinterpret_cast<const char*>(chunk), nbytes);
        buffer.attr("write")(std::move(payload));
        return nbytes;
      };
      return std::make_unique<PyTorchStreamWriter>(std::move(write_chunk));
    }))
    .def(py::init<const std::function<size_t(const void*, size_t)>&>())
    .def(
        "write_record",
        [](PyTorchStreamWriter& writer,
           const std::string& record_name,
           const char* bytes,
           size_t nbytes) {
          return writer.writeRecord(record_name, bytes, nbytes);
        })
    .def("write_end_of_file", &PyTorchStreamWriter::writeEndOfFile)
    .def(
        "write_record",
        [](PyTorchStreamWriter& writer,
           const std::string& record_name,
           uintptr_t address,
           size_t nbytes) {
          // The integer is interpreted as the address of the payload.
          const char* bytes = reinterpret_cast<const char*>(address);
          return writer.writeRecord(record_name, bytes, nbytes);
        });
// This allows PyTorchStreamReader to read from a Python buffer. It requires
// that the buffer implement `seek()`, `tell()`, and `read()`.
class BufferAdapter : public caffe2::serialize::ReadAdapterInterface {
public:
BufferAdapter(const py::object& buffer) : buffer_(buffer) {
// Jump to the end of the buffer to get its size
auto current = buffer.attr("tell")();
start_offset_ = py::cast<size_t>(current);
buffer.attr("seek")(current, py::module::import("os").attr("SEEK_END"));
size_ = py::cast<size_t>(buffer.attr("tell")()) - start_offset_;
buffer.attr("seek")(current);
// If we can read directly into a buffer, do that instead of an extra copy
use_readinto_ = py::hasattr(buffer, "readinto");
}
size_t size() const override {
return size_;
}
THPObjectPtr getMemview(void* buf, size_t n) const {
#if PY_MAJOR_VERSION >= 3
THPObjectPtr memview(PyMemoryView_FromMemory(
reinterpret_cast<char*>(buf), n, PyBUF_WRITE));
#else
THPObjectPtr memview(PyBuffer_FromReadWriteMemory(buf, n));
#endif
if (!memview) {
throw python_error();
}
return memview;
}
size_t read(uint64_t pos, void* buf, size_t n, const char* what)
const override {
// Seek to desired position (NB: this has to be a Py_ssize_t or Python
// throws a weird error)
Py_ssize_t absolute_pos = start_offset_ + pos;
buffer_.attr("seek")(absolute_pos);
if (use_readinto_) {
auto memview = getMemview(buf, n);
auto res =
PyObject_CallMethod(buffer_.ptr(), "readinto", "O", memview.get());
if (res) {
int i = PyInt_AsLong(res);
if (i > 0) {
return i;
}
}
}
// Read bytes into `buf` from the buffer
std::string bytes = py::cast<std::string>(buffer_.attr("read")(n));
std::copy(
bytes.data(),
bytes.data() + bytes.size(),
reinterpret_cast<char*>(buf));
return bytes.size();
}
py::object buffer_;
size_t size_;
size_t start_offset_;
bool use_readinto_;
};
// Python-visible reader. It can be constructed from a file name or from a
// Python file-like object, which is wrapped in a BufferAdapter.
py::class_<PyTorchStreamReader>(m, "PyTorchFileReader")
    .def(py::init<std::string>())
    .def(py::init([](const py::object& buffer) {
      auto source = std::make_unique<BufferAdapter>(buffer);
      return std::make_unique<PyTorchStreamReader>(std::move(source));
    }))
    .def(
        "get_record",
        [](PyTorchStreamReader& reader, const std::string& key) {
          // getRecord yields (data pointer, size); return the payload to
          // Python as a `bytes` object.
          auto record = reader.getRecord(key);
          auto& payload = std::get<0>(record);
          size_t nbytes = std::get<1>(record);
          return py::bytes(
              reinterpret_cast<const char*>(payload.get()), nbytes);
        })
    .def("get_all_records", [](PyTorchStreamReader& reader) {
      return reader.getAllRecords();
    });
// Looks up every operator registered under a qualified name and returns a
// Python callable that dispatches to them. The callable's docstring lists the
// schemas found. c10::Error is rethrown as std::runtime_error without the
// backtrace.
m.def(
    "_jit_get_operation",
    [](const std::string& op_name) {
      try {
        auto op_symbol = Symbol::fromQualString(op_name);
        auto overloads = getAllOperatorsFor(op_symbol);
        TORCH_CHECK(!overloads.empty(), "No such operator ", op_name);

        // Build the docstring: a header line followed by one schema per line.
        std::ostringstream doc;
        doc << "Automatically bound operator '" << op_name
            << "' with schema(s):\n";
        for (const auto& overload : overloads) {
          doc << " " << overload->schema() << "\n";
        }

        // The returned function keeps its own copy of the overload list.
        auto dispatcher = [overloads](py::args args, py::kwargs kwargs) {
          return invokeOperatorFromPython(
              overloads, std::move(args), std::move(kwargs));
        };
        return py::cpp_function(
            dispatcher,
            py::name(op_symbol.toUnqualString()),
            py::doc(doc.str().c_str()));
      } catch (const c10::Error& error) {
        throw std::runtime_error(error.what_without_backtrace());
      }
    },
    py::arg("qualified_name"));
// Parses textual IR into a freshly created Graph and returns that graph.
m.def("parse_ir", [](const std::string& input) {
  std::shared_ptr<Graph> parsed = std::make_shared<Graph>();
  parseIR(input, parsed.get());
  return parsed;
});
// Direct binding of parseSchema.
m.def("parse_schema", parseSchema);
// Read-only Python view of a FunctionSchema, plus equality, string rendering
// and a backward-compatibility check against an older schema.
py::class_<FunctionSchema>(m, "FunctionSchema")
    .def_property_readonly(
        "name", [](FunctionSchema& schema) { return schema.name(); })
    .def_property_readonly(
        "overload_name",
        [](FunctionSchema& schema) { return schema.overload_name(); })
    .def_property_readonly(
        "arguments",
        [](FunctionSchema& schema) { return schema.arguments(); })
    .def_property_readonly(
        "returns", [](FunctionSchema& schema) { return schema.returns(); })
    .def(
        "is_backward_compatible_with",
        [](const FunctionSchema& current, const FunctionSchema& old_schema) {
          return current.isBackwardCompatibleWith(old_schema);
        })
    .def(
        "__eq__",
        [](const FunctionSchema& lhs, const FunctionSchema& rhs) {
          return lhs == rhs;
        })
    .def("__str__", [](FunctionSchema& schema) {
      // str() is whatever the schema's stream-insertion operator prints.
      std::stringstream rendered;
      rendered << schema;
      return rendered.str();
    });
// Read-only Python view of a schema Argument. `N` and `default_value` are
// None when the argument does not carry them.
py::class_<Argument>(m, "Argument")
    .def_property_readonly("name", [](Argument& arg) { return arg.name(); })
    .def_property_readonly("type", [](Argument& arg) { return arg.type(); })
    .def_property_readonly(
        "N",
        [](Argument& arg) -> py::object {
          if (arg.N()) {
            return py::cast(*arg.N());
          }
          return py::none();
        })
    .def_property_readonly("default_value", [](Argument& arg) -> py::object {
      if (arg.default_value()) {
        // Copy the default out of the argument before converting it.
        IValue default_copy = *arg.default_value();
        return toPyObject(std::move(default_copy));
      }
      return py::none();
    });
// Returns the schema of every registered operator.
m.def("_jit_get_all_schemas", []() {
  const std::vector<std::shared_ptr<Operator>>& all_ops = getAllOperators();
  return fmap(all_ops, [](const std::shared_ptr<Operator>& entry) {
    return entry->schema();
  });
});
// Returns the schemas of the operators registered under one qualified name.
m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) {
  auto op_symbol = Symbol::fromQualString(qualified_name);
  auto overloads = getAllOperatorsFor(op_symbol);
  return fmap(overloads, [](const std::shared_ptr<Operator>& entry) {
    return entry->schema();
  });
});
// Reports whether the tracer is currently recording.
m.def("_is_tracing", []() { return jit::tracer::isTracing(); });
// Python-visible Future. `wait` runs under a call guard that releases the GIL
// and takes it back only for the conversion of the result.
py::class_<PythonFutureWrapper>(m, "Future")
    .def(
        "wait",
        [](PythonFutureWrapper& future) {
          auto value = future.wait();
          // acquiring GIL as toPyObject creates new py::object
          // without grabbing the GIL.
          pybind11::gil_scoped_acquire gil;
          return toPyObject(std::move(value));
        },
        py::call_guard<py::gil_scoped_release>());
// Python-side `fork(fn, *args)`: calls `fn(*args)` right away on the calling
// thread and returns a Future that is already marked completed with the
// result. While tracing, it additionally records a prim::TracedFork node whose
// sub-block receives the ops traced during the call.
m.def("fork", [](py::args args) {
// The first positional argument is the callable; the rest are its arguments.
AT_ASSERT(args.size() >= 1);
py::function f = py::cast<py::function>(args[0]);
py::tuple args_tup(args.size() - 1);
for (size_t i = 1; i < args.size(); ++i) {
args_tup[i - 1] = args[i];
}
if (jit::tracer::isTracing()) {
auto graph = jit::tracer::getTracingState()->graph;
auto fork_node = graph->insertNode(graph->create(prim::TracedFork, 1));
auto body_block = fork_node->addBlock();
Value* node_output;
py::object py_func_output;
// Insert new trace ops into the fork op's sub-block
WithInsertPoint guard(body_block);
IValue output_ivalue;
{
// The nested tracing frame lives only for this inner scope.
tracer::WithNestedTracingFrame env_guard;
// Run the user-supplied function
py_func_output = f(*args_tup);
// Convert the output of the user-supplied function to IValue. The type
// information of this IValue is used to record the correct type in
// the trace.
output_ivalue = toTypeInferredIValue(py_func_output);
Value* out_val = jit::tracer::getValueTrace(output_ivalue);
body_block->registerOutput(out_val);
// The fork node's output is typed as a Future of the traced output type.
node_output =
fork_node->output()->setType(FutureType::create(out_val->type()));
}
auto retval =
c10::make_intrusive<c10::ivalue::Future>(output_ivalue.type());
// Record the ivalue in the tracer
jit::tracer::setValueTrace(retval, node_output);
// stuff the ivalue output in the Future
retval->markCompleted(output_ivalue);
return PythonFutureWrapper(retval);
} else {
// Not tracing: run the function and wrap its result in a completed Future.
auto result = toTypeInferredIValue(f(*args_tup));
auto retval = c10::make_intrusive<c10::ivalue::Future>(result.type());
retval->markCompleted(std::move(result));
return PythonFutureWrapper(retval);
}
});
// Module-level counterpart of Future.wait: forwards to the wrapper's wait().
m.def("wait", [](PythonFutureWrapper& future) { return future.wait(); });
// Converts `value` to an IValue of `expected_type` and discards the result.
// NOTE(review): presumably relied on for the error toIValue raises on a type
// mismatch — confirm against toIValue.
m.def("_jit_assert_is_instance", [](py::object value, TypePtr expected_type) {
  toIValue(value, expected_type);
});
// Register the remaining binding groups on the same module object. The call
// order is kept as-is; NOTE(review): later groups may depend on types that
// earlier groups register — confirm before reordering.
initPythonCustomClassBindings(module);
initPythonIRBindings(module);
tracer::initPythonTracerBindings(module);
initTreeViewBindings(module);
initJitScriptBindings(module);
initJitBackendBindings(module);
// Route the JIT print handler to Python's sys.stdout. The GIL is taken for
// the duration of the write, and Python errors are rethrown as
// std::runtime_error.
setPrintHandler([](const std::string& text) {
  py::gil_scoped_acquire gil;
  try {
    auto py_stdout = py::module::import("sys").attr("stdout");
    py_stdout.attr("write")(text);
  } catch (py::error_already_set& py_err) {
    throw std::runtime_error(py_err.what());
  }
});
}
} // namespace jit
} // namespace torch