mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/glow/pull/5029 Support single element tuples in to_backend Test Plan: new unit test for to_glow Reviewed By: andrewmillspaugh Differential Revision: D24539869 fbshipit-source-id: fb385a7448167b2b948e70f6af081bcf78f338dc
250 lines
9.3 KiB
C++
250 lines
9.3 KiB
C++
#include <torch/csrc/jit/backends/backend_init.h>
|
|
#include <torch/csrc/jit/backends/backend_detail.h>
|
|
#include <torch/csrc/jit/backends/backend_resolver.h>
|
|
#include <torch/csrc/jit/frontend/code_template.h>
|
|
#include <torch/csrc/jit/python/pybind_utils.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
void initJitBackendBindings(PyObject* module) {
|
|
// Bind a function for lowering to each JIT backend. The name of the backend
|
|
// must be the first argument. For example, to lower a Module to
|
|
// "example_backend", declared as
|
|
//
|
|
// static auto cls = torch::jit::backend<ExampleBackend>("example_backend");
|
|
//
|
|
// this function must be called like
|
|
//
|
|
// torch._C._jit_to_backend("example_backend", module, spec)
|
|
auto codegen_lambda = [=](const std::string& backend_name,
|
|
const Module& orig_module,
|
|
const py::dict& method_compile_spec) {
|
|
const c10::QualifiedName qual_backend_name({"__torch__",
|
|
"torch",
|
|
"classes",
|
|
detail::kBackendsNamespace,
|
|
backend_name});
|
|
// TODO: Validate method_compile_spec.
|
|
|
|
// Clone orig_module to make sure backend transformation is
|
|
// functional.
|
|
auto cloned_module = orig_module.clone();
|
|
|
|
// Represents of a Type of Dict[str, Any].
|
|
auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
|
|
|
|
// Generate LoweredModule.
|
|
Module loweredModule(
|
|
"torch.jit." + backend_name + "LoweredModule",
|
|
get_python_cu(),
|
|
/*shouldMangle=*/true);
|
|
|
|
// Generate attributes.
|
|
// This is the original cloned and preprocessed module.
|
|
loweredModule.register_attribute(
|
|
"__processed_module",
|
|
AnyType::get(),
|
|
cloned_module._ivalue(),
|
|
/*is_param=*/false);
|
|
|
|
// This is for the method_compile_spec passed in to to_<backend> or
|
|
// loaded from an exported model.
|
|
loweredModule.register_attribute(
|
|
"__method_compile_spec",
|
|
any_dict_ty,
|
|
toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
|
|
/*is_param=*/false);
|
|
|
|
// This is a pointer to a backend instance that is used to access
|
|
// compile and execute functions.
|
|
auto cls = getCustomClass(qual_backend_name.qualifiedName());
|
|
TORCH_INTERNAL_ASSERT(cls);
|
|
c10::intrusive_ptr<torch::CustomClassHolder> backend;
|
|
loweredModule.register_attribute(
|
|
"__backend", cls, IValue::make_capsule(backend));
|
|
|
|
// This is the list of opaque backend handles returned by
|
|
// backend.compile.
|
|
loweredModule.register_attribute(
|
|
"__handles",
|
|
any_dict_ty,
|
|
c10::impl::GenericDict(
|
|
any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
|
|
/*is_param=*/false);
|
|
|
|
// Methods.
|
|
|
|
// This is a helper function for creating a new instance of the
|
|
// backend class.
|
|
static const auto create_backend_ct = CodeTemplate(R"(
|
|
def __create_backend(self):
|
|
self.__backend = $name()
|
|
)");
|
|
TemplateEnv create_backend_te;
|
|
create_backend_te.s("name", qual_backend_name.qualifiedName());
|
|
loweredModule.define(
|
|
create_backend_ct.format(create_backend_te), loweredModuleResolver());
|
|
|
|
// getstate and setstate are for serialization/deserialization of
|
|
// the LoweredModule.
|
|
loweredModule.define(
|
|
R"(
|
|
def __getstate__(self):
|
|
return self.__method_compile_spec, self.__processed_module
|
|
)",
|
|
loweredModuleResolver());
|
|
|
|
loweredModule.define(
|
|
R"(
|
|
def __setstate__(self, state):
|
|
self.__method_compile_spec = state[0]
|
|
self.__processed_module = state[1]
|
|
self.__create_backend()
|
|
self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
|
|
)",
|
|
loweredModuleResolver());
|
|
|
|
// This is never called during compilation or execution, but is
|
|
// needed to generate the LoweredModule because we don't have access
|
|
// to an instance of the backend as a C++ object with which to call
|
|
// preprocess.
|
|
loweredModule.define(
|
|
R"(
|
|
def __preprocess(self, mod: Any, method_compile_spec: Dict[str, Any]):
|
|
self.__create_backend()
|
|
self.__processed_module = self.__backend.preprocess(mod, method_compile_spec)
|
|
)",
|
|
loweredModuleResolver());
|
|
|
|
// This loop generates one method on the LoweredModule for every key
|
|
// in method_compile_spec.
|
|
for (auto& e : method_compile_spec) {
|
|
std::string method_name = py::cast<std::string>(e.first);
|
|
static const auto method_ct = CodeTemplate(R"(
|
|
def $method(self${,def_inputs}):
|
|
typed_inputs: List[Any] = [${fwd_inputs,}]
|
|
$unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs)
|
|
${refine,}
|
|
return $ret
|
|
)");
|
|
|
|
TemplateEnv method_te;
|
|
method_te.s("method", method_name);
|
|
auto method = orig_module.get_method(method_name);
|
|
auto& function = method.function();
|
|
auto& schema = function.getSchema();
|
|
|
|
// Generate the inputs for the function signature (def_inputs) and
|
|
// for passing to backend.execute (fwd_inputs).
|
|
std::vector<std::string> def_inputs, fwd_inputs;
|
|
for (const auto& arg : schema.arguments()) {
|
|
auto name = arg.name();
|
|
|
|
// Skip self since that is only and always present in the
|
|
// signature.
|
|
if (name == "self") {
|
|
continue;
|
|
}
|
|
|
|
auto default_value = arg.default_value();
|
|
|
|
if (arg.kwarg_only()) {
|
|
// If this is a kwarg, it needs to be emitted as keyword=value
|
|
// in the definition and keyword=keyword in the call to
|
|
// backend_execute.
|
|
TORCH_INTERNAL_ASSERT(default_value.has_value());
|
|
std::stringstream def_ss, fwd_ss;
|
|
def_ss << name << "=";
|
|
fwd_ss << name << "=" << name;
|
|
default_value->repr(def_ss, [](std::ostream&, const IValue&) -> bool {
|
|
return false;
|
|
});
|
|
def_inputs.emplace_back(def_ss.str());
|
|
fwd_inputs.emplace_back(fwd_ss.str());
|
|
} else {
|
|
// If this is not a kwarg, it should be emitted as is in the
|
|
// signature and the call to backend_execute.
|
|
def_inputs.emplace_back(name);
|
|
fwd_inputs.emplace_back(name);
|
|
}
|
|
}
|
|
|
|
// Generate a comma-delimited list of identifiers to unpack
|
|
// outputs, as well as a list of isinstance checks to make sure
|
|
// the backend returned the types it was supposed to.
|
|
std::stringstream out_ss, type_check_ss;
|
|
std::vector<std::string> type_checks;
|
|
TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
|
|
auto out_ty = schema.returns().at(0).type();
|
|
|
|
out_ss << "_0";
|
|
type_check_ss << "assert isinstance(_0, ";
|
|
|
|
auto out_tuple_ty = out_ty->cast<TupleType>();
|
|
|
|
if (out_tuple_ty) {
|
|
auto tuple_elements = out_tuple_ty->elements();
|
|
type_check_ss << tuple_elements[0]->str() << ")";
|
|
type_checks.emplace_back(type_check_ss.str());
|
|
for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) {
|
|
type_check_ss.str(std::string());
|
|
type_check_ss.clear();
|
|
out_ss << ", _" << i;
|
|
type_check_ss << "assert isinstance(_" << i << ", "
|
|
<< tuple_elements[i]->str() << ")";
|
|
type_checks.emplace_back(type_check_ss.str());
|
|
}
|
|
} else {
|
|
type_check_ss << out_ty->str() << ")";
|
|
type_checks.emplace_back(type_check_ss.str());
|
|
}
|
|
|
|
method_te.v("def_inputs", def_inputs);
|
|
method_te.v("fwd_inputs", fwd_inputs);
|
|
method_te.v("refine", type_checks);
|
|
method_te.s("unpack", out_ss.str());
|
|
|
|
// If the output type is a single element tuple then add an extra comma
|
|
// to ensure the final output maintains this type.
|
|
if (out_tuple_ty && out_tuple_ty->elements().size() == 1) {
|
|
out_ss << ",";
|
|
}
|
|
|
|
method_te.s("ret", out_ss.str());
|
|
|
|
loweredModule.define(
|
|
method_ct.format(method_te), loweredModuleResolver());
|
|
}
|
|
|
|
// Run preprocess so that __processed_module is set correctly before
|
|
// compilation.
|
|
loweredModule.run_method(
|
|
"__preprocess",
|
|
cloned_module._ivalue(),
|
|
toIValue(method_compile_spec, any_dict_ty).toGenericDict());
|
|
|
|
// Call __setstate__ to ensure that the returned Module is ready to
|
|
// run.
|
|
auto state = at::ivalue::Tuple::create(
|
|
toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
|
|
loweredModule.attr("__processed_module"));
|
|
loweredModule.run_method("__setstate__", state);
|
|
return loweredModule;
|
|
};
|
|
auto m = py::handle(module).cast<py::module>();
|
|
m.def(
|
|
"_jit_to_backend",
|
|
[=](const std::string& backend_name,
|
|
py::handle orig_module,
|
|
const py::dict& method_compile_spec) {
|
|
return py::module::import("torch.jit._recursive")
|
|
.attr("wrap_cpp_module")(codegen_lambda(
|
|
backend_name,
|
|
py::cast<Module>(orig_module.attr("_c")),
|
|
method_compile_spec));
|
|
});
|
|
}
|
|
} // namespace jit
|
|
} // namespace torch
|