pytorch/torch/csrc/jit/backends/backend_init.cpp
Kimish Patel 2ce21b2e61 [Pytorch backend delegation] Preprocess to accept (#58873)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58873

BackendDebugInfoRecorder

Prior to this PR:
In order to generate debug handles corresponding to the graph being
lowered, a backend's preprocess will call generate_debug_handles and will
get a map of Node*-to-debug_handles.
In order to facilitate this, to_backend will own
BackendDebugInfoRecorder and initialize thread local pointer to it.
The generate_debug_handles function will query the thread-local pointer to see if
there is a valid BackendDebugInfoRecorder for the context. If there is,
it will generate debug handles.

After this PR:
Signature of preprocess is changed such that backends have to register
preprocess that accepts instance of BackendDebugInfoRecorder by
reference. generate_debug_handles is no longer a free function but becomes
part of the API of BackendDebugInfoRecorder. Now backend's preprocess
function will call generate_debug_handles on BackendDebugInfoRecorder
instead of free function.

Reason for this change:
Using RAII to initialize a thread-local pointer results in a loose
contract with backends, which may result in backends not storing
debug information. Making it part of the API requires
backends to be aware of BackendDebugInfoRecorder and to explicitly
choose not to generate/store debug information if they so wish.

Test Plan:
backend tests

Imported from OSS

Reviewed By: jbschlosser, raziel

Differential Revision: D28648613

fbshipit-source-id: c9b7e7bf0f78e87023ea7bc08612cf893b08cb98
2021-06-11 10:16:00 -07:00

194 lines
7.3 KiB
C++

#include <torch/csrc/jit/backends/backend_init.h>
#include <pybind11/iostream.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_resolver.h>
#include <torch/csrc/jit/frontend/code_template.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/pybind_utils.h>
namespace torch {
namespace jit {
// Return the module types that occur more than once in the module hierarchy
// rooted at \p mod (the root itself is included in the scan).
//
// toBackendSelectiveImpl edits a lowered submodule's parent type in place,
// and uses this set to refuse lowering when that parent type is shared.
std::unordered_set<TypePtr> getSharedModuleTypes(Module& mod) {
  // Every type encountered so far.
  std::unordered_set<TypePtr> seen;
  // Types encountered at least twice.
  std::unordered_set<TypePtr> shared;
  for (auto child : mod.modules()) {
    auto child_type = child.type();
    // insert() reports `false` in .second when the type was already present,
    // i.e. this is a repeat occurrence.
    const bool first_occurrence = seen.insert(child_type).second;
    if (!first_occurrence) {
      shared.insert(child_type);
    }
  }
  return shared;
}
// Selectively lower \p mod to a backend. \p to_backend
// is called to lower modules. \p modules_to_lower contains
// qualified names of submodules of \p mod that should be lowered.
void toBackendSelectiveImpl(
Module& mod,
const py::function& to_backend,
const std::vector<std::string>& modules_to_lower,
const std::unordered_set<TypePtr>& duplicate_types) {
// This map will be used later to remap types in ancestor module graphs for
// all lowered submodules.
std::unordered_map<TypePtr, TypePtr> type_remap;
// For each module that should be lowered:
for (const auto& module_to_lower : modules_to_lower) {
// Use QualifiedName to parse the qualified module names.
c10::QualifiedName qual_module_name(module_to_lower);
auto& atoms = qual_module_name.atoms();
// Search through the module hierarchy using the atoms of
// qual_module_name until current points to the module to
// be lowered and parent points to its parent.
Module current = mod;
Module parent;
for (size_t i = 0, e = atoms.size(); i < e; ++i) {
IValue submodule = current.attr(atoms[i]);
if (submodule.isModule()) {
if (i == e - 1) {
parent = current;
}
current = submodule.toModule();
} else {
std::stringstream err;
err << "Attribute named " << atoms[i] << " is not a Module";
throw std::runtime_error(err.str());
}
}
// Check that the parent type is not shared and therefore can be edited.
if (duplicate_types.count(parent.type()) > 0) {
throw py::cast_error(c10::str(
"Selective lowering is only supported for module hierarchies with unique types for selected modules; ",
parent.type()->repr_str(),
" is shared"));
}
// Call to_backend on the module that needs to be lowered. It needs to be
// wrapped before doing so because _to_jit_backend accepts wrapped modules.
// The result needs to be unwrapped in order to access its type below.
auto lowered_submodule =
py::cast<Module>(to_backend(py::module::import("torch.jit._recursive")
.attr("wrap_cpp_module")(current))
.attr("_c"));
// Adjust the parent's type so that the type of the submodule matches
// the type of lowered_submodule.
auto parent_type = parent.type();
parent_type->unsafeChangeAttributeType(
atoms.back(), lowered_submodule.type());
parent.setattr(atoms.back(), lowered_submodule._ivalue());
// Record the type mapping from old type -> lowered type.
type_remap[current.type()] = lowered_submodule.type();
}
// Having lowered all of the modules that needed to be lowered, remap types in
// all graphs in the hierarchy so that the graphs all use the new lowered
// type.
auto type_remap_fn = [&type_remap](TypePtr in) {
auto it = type_remap.find(in);
if (it == type_remap.end())
return in;
return it->second;
};
// modules() iterates over all modules in the hierarchy including the root.
for (auto module : mod.modules()) {
auto module_type = module.type();
for (auto& fn : module_type->methods()) {
auto method = module.get_method(fn->name());
auto graph = method.graph();
graph->remapTypes(type_remap_fn);
auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn);
fn->setSchema(new_schema);
}
}
}
Module codegen_func(
const std::string& backend_name,
const Module& orig_module,
const py::dict& method_compile_spec) {
// Represents of a Type of Dict[str, Any].
auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
return detail::codegen_backend_module(
backend_name,
orig_module,
toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
any_dict_ty);
}
// Register the Python bindings used to lower modules to JIT backends.
//
//   torch._C._jit_to_backend(backend_name, module, spec)
//     Lowers the whole wrapped ScriptModule `module` to the backend registered
//     under `backend_name`. For example, a backend declared as
//       static auto cls = torch::jit::backend<ExampleBackend>("example_backend");
//     is targeted with
//       torch._C._jit_to_backend("example_backend", module, spec)
//     The lowered module is returned wrapped by
//     torch.jit._recursive.wrap_cpp_module.
//
//   torch._C._jit_to_backend_selective(module, to_backend, modules_to_lower)
//     Clones `module`, then lowers only the submodules named in
//     `modules_to_lower` by calling the Python callable `to_backend` on each.
//     Throws py::cast_error if `module` is not a ScriptModule.
//
// While either binding runs, C++ std::cerr/std::cout output is redirected to
// Python's sys.stderr/sys.stdout.
void initJitBackendBindings(PyObject* module) {
  auto py_module = py::handle(module).cast<py::module>();

  py_module.def(
      "_jit_to_backend",
      [](const std::string& backend_name,
         py::handle orig_module,
         const py::dict& method_compile_spec) {
        py::scoped_ostream_redirect err_redirect(
            std::cerr, py::module_::import("sys").attr("stderr"));
        py::scoped_ostream_redirect out_redirect(
            std::cout, py::module_::import("sys").attr("stdout"));
        // Unwrap the Python ScriptModule to its C++ Module via `_c`.
        Module cpp_module = py::cast<Module>(orig_module.attr("_c"));
        Module lowered =
            codegen_func(backend_name, cpp_module, method_compile_spec);
        return py::module::import("torch.jit._recursive")
            .attr("wrap_cpp_module")(lowered);
      });

  py_module.def(
      "_jit_to_backend_selective",
      [](py::handle orig_module,
         const py::function& to_backend,
         const std::vector<std::string>& modules_to_lower) {
        py::scoped_ostream_redirect err_redirect(
            std::cerr, py::module_::import("sys").attr("stderr"));
        py::scoped_ostream_redirect out_redirect(
            std::cout, py::module_::import("sys").attr("stdout"));
        auto maybe_script_module =
            as_module(py::cast<py::object>(orig_module));
        if (!maybe_script_module) {
          throw py::cast_error(c10::str(
              "Object ", py::str(orig_module), " is not a ScriptModule"));
        }
        // Work on a clone so that types shared with modules outside this
        // hierarchy are never edited.
        Module working_copy = maybe_script_module->clone();
        // Lowering is refused when a selected module's parent type appears in
        // this set (see toBackendSelectiveImpl).
        auto shared_types = getSharedModuleTypes(working_copy);
        toBackendSelectiveImpl(
            working_copy, to_backend, modules_to_lower, shared_types);
        // Return a RecursiveScriptModule, matching what the caller passed in.
        return py::module::import("torch.jit._recursive")
            .attr("wrap_cpp_module")(working_copy);
      });
}
} // namespace jit
} // namespace torch