#include "torch/csrc/jit/script/init.h"
|
|
|
|
#include "torch/csrc/Device.h"
|
|
#include "torch/csrc/Dtype.h"
|
|
#include "torch/csrc/Layout.h"
|
|
#include "torch/csrc/jit/import.h"
|
|
#include "torch/csrc/jit/script/compiler.h"
|
|
|
|
#include "torch/csrc/jit/python_tracer.h"
|
|
#include "torch/csrc/jit/pybind_utils.h"
|
|
#include "torch/csrc/jit/constants.h"
|
|
#include "torch/csrc/jit/passes/to_batch.h"
|
|
#include "torch/csrc/jit/function_schema.h"
|
|
#include "torch/csrc/jit/script/parser.h"
|
|
#include "torch/csrc/jit/import_method.h"
|
|
#include "torch/csrc/jit/hooks_for_testing.h"
|
|
#include "torch/csrc/jit/passes/python_print.h"
|
|
|
|
#include <torch/csrc/api/include/torch/ordered_dict.h>
|
|
|
|
#include <ATen/ATen.h>
|
|
|
|
#include <cstddef>
|
|
#include <memory>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <tuple>
|
|
#include <utility>
|
|
#include <vector>
|
|
#include <pybind11/functional.h>
|
|
|
|
|
|
namespace torch {
namespace jit {
namespace script {

using ResolutionCallback = std::function<py::function(std::string)>;
using FunctionDefaults = std::unordered_map<std::string, py::object>;

static std::string typeString(py::handle h) {
  return py::str(h.get_type().attr("__name__"));
}

inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
  return std::make_shared<SimpleValue>(v);
}

// NB: This should be the single entry-point for instantiating a SugaredValue
// from a Python object. If you are adding support for converting a new Python
// type, *add it in this function's implementation*.
std::shared_ptr<SugaredValue> toSugaredValue(
    py::object obj,
    Method& m,
    SourceRange loc,
    bool is_constant = false,
    bool is_submodule = false);

struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
  PythonValue(py::object self)
      : self(std::move(self)) {}

  FunctionSchema getSchema(const size_t n_args, const size_t n_binders) {
    auto annotations = py::module::import("torch.jit.annotations");
    auto signature = annotations.attr("get_signature")(self);
    std::vector<Argument> args, rets;
    // We may mutate this if we can determine the number of args from Python
    // introspection.
    size_t actual_n_args = n_args;
    if (!signature.is_none()) {
      std::vector<TypePtr> arg_types;
      TypePtr ret_type;
      std::tie(arg_types, ret_type) =
          py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
      args.reserve(arg_types.size());
      size_t idx = 0; // Fake argument names by putting in the index
      for (auto& arg_type : arg_types) {
        args.push_back(
            Argument(std::to_string(idx++), std::move(arg_type), {}, {}, false));
      }
      rets.push_back(Argument("0", std::move(ret_type), {}, {}, false));
    } else {
      // Create a default signature using what information we have

      // First see if we can introspect the number of function parameters
      // irrespective of the presence of explicit type annotations
      auto num_params = annotations.attr("get_num_params")(self);
      if (!num_params.is_none()) {
        // Return a signature with the correct number of params according to
        // the Python function. The error handling in call() will catch any
        // mismatch later.
        actual_n_args = py::cast<size_t>(num_params);
      }
      // Construct the default signature: all arguments and returns will be
      // DynamicType
      args.reserve(actual_n_args);
      for (size_t i = 0; i < actual_n_args; ++i) {
        args.push_back(
            Argument(std::to_string(i), DynamicType::get(), {}, {}, false));
      }
      TypePtr ret_type = DynamicType::get();
      if (n_binders != 1) {
        std::vector<TypePtr> tuple_values(n_binders, ret_type);
        ret_type = TupleType::create(std::move(tuple_values));
      }
      rets.push_back(Argument("0", ret_type, {}, {}, false));
    }
    return FunctionSchema("", std::move(args), std::move(rets));
  }
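
  // A sketch (not from the original source) of what getSchema builds for an
  // un-annotated Python function, given a hypothetical caller:
  //
  //   # Python side:
  //   def fn(x, y):
  //       return x + y
  //
  // With no signature from torch.jit.annotations, every argument and the
  // return are DynamicType, and positional indices stand in for names,
  // roughly "(Dynamic 0, Dynamic 1) -> Dynamic 0". When n_binders != 1
  // (e.g. `a, b = fn(x, y)`), the return type becomes a tuple of n_binders
  // DynamicTypes instead.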

  // call it like a function, e.g. `outputs = this(inputs)`
  virtual std::shared_ptr<SugaredValue> call(
      SourceRange loc,
      Method& m,
      at::ArrayRef<NamedValue> inputs_,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    auto inputs = toValues(*m.graph(), inputs_);
    auto schema = getSchema(inputs.size(), n_binders);

    std::stringstream failure_messages;
    c10::optional<MatchedSchema> matched_schema = tryMatchSchema(
        schema,
        loc,
        *m.graph(),
        c10::nullopt,
        inputs_,
        attributes,
        failure_messages,
        /*conv_tensor_to_num=*/true);
    if (!matched_schema)
      throw ErrorReport(loc) << failure_messages.str();

    // Release the function object so we can wrap it in a PythonOp
    py::object func = self;
    std::string cconv(inputs.size(), 'd');
    Node* new_node = m.graph()->insertNode(m.graph()->createPythonOp(
        THPObjectPtr(func.release().ptr()), cconv, {}));
    new_node->setSourceLocation(std::make_shared<SourceRange>(loc));
    for (auto& i : matched_schema->inputs)
      new_node->addInput(i);

    std::vector<Value*> outputs;
    for (auto& ret_arg : matched_schema->return_types) {
      outputs.push_back(new_node->addOutput()->setType(ret_arg));
    }
    return std::make_shared<SimpleValue>(packOutputs(*m.graph(), outputs));
  }

  virtual std::string kind() const override {
    std::stringstream ss;
    ss << "python value of type '" << typeString(self) << "'";
    return ss.str();
  }

 protected:
  py::object getattr(SourceRange loc, const std::string& name) {
    try {
      return py::getattr(self, name.c_str());
    } catch (py::error_already_set& e) {
      throw ErrorReport(loc) << "object has no attribute " << name;
    }
  }

  py::object self;
};

struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
  explicit PythonModuleValue(py::object mod) : PythonValue(mod) {}

  std::shared_ptr<SugaredValue> attr(
      SourceRange loc,
      Method& m,
      const std::string& field) override {
    py::object member = getattr(loc, field);
    // note: is_constant = true because we consider global properties on
    // modules, like math.pi or torch.float, to be constants, even though it
    // is possible (though rare) for someone to mutate them
    return toSugaredValue(member, m, loc, /*is_constant=*/true);
  }
};
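
// A sketch (not from the original source) of the constant-folding behavior
// above, from the Python side; `area` is a hypothetical function:
//
//   import math
//   @torch.jit.script
//   def area(r):
//       return math.pi * r * r  # math.pi is resolved through
//                               # PythonModuleValue and baked into the graph
//                               # as a float constant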

struct VISIBILITY_HIDDEN ConstantPythonTupleValue : public PythonValue {
  explicit ConstantPythonTupleValue(py::object tup) : PythonValue(tup) {}
  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      SourceRange loc,
      Method& m,
      c10::optional<size_t> size_hint = {}) override {
    py::tuple tup = self;
    std::vector<std::shared_ptr<SugaredValue>> result;
    result.reserve(tup.size());
    for (size_t i = 0; i < tup.size(); ++i) {
      result.push_back(toSugaredValue(tup[i], m, loc, true));
    }
    return result;
  }

  Value* asValue(
      SourceRange loc,
      Method& m) override {
    std::vector<Value*> values;
    for (auto sugared_item : asTuple(loc, m)) {
      values.push_back(sugared_item->asValue(loc, m));
    }
    auto node = m.graph()->createTuple(values);
    return m.graph()->insertNode(node)->output();
  }
};

// Defines how modules/methods behave inside the script subset.
// For now this does not have any interaction with Python.
// In the future, we will add the ability to resolve `self.foo` to Python
// {functions, modules, constants}, so this SugaredValue is defined here,
// anticipating that we will eventually need to replace Module with a
// py::object holding the actual nn.Module class.

struct ModuleValue : public SugaredValue {
  ModuleValue(std::shared_ptr<Module> module)
      : module(std::move(module)) {}

  virtual std::string kind() const override {
    return "module";
  }

  // select an attribute on it, e.g. `this.field`
  virtual std::shared_ptr<SugaredValue> attr(
      SourceRange loc,
      Method& m,
      const std::string& field) override {
    if (NamedModule* v = module->find_module(field)) {
      return std::make_shared<ModuleValue>(v->module);
    } else if (Method* v = module->find_method(field)) {
      return std::make_shared<MethodValue>(module, *v);
    } else if (NamedParameter* v = module->find_parameter(field)) {
      return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
    }
    // This can also be a call to a non-script module, or a plain
    // Python method. If so, return this as a Python value.
    py::object py_module = py::cast(module);
    if (py::object attr = py::getattr(py_module, field.c_str(), py::none())) {
      if (py::isinstance<py::function>(attr) ||
          py::isinstance(attr, py::module::import("torch.nn").attr("Module")) ||
          py_module.attr("_constants_set").contains(field.c_str())) {
        return toSugaredValue(attr, m, loc, true);
      } else {
        throw ErrorReport(loc)
            << "attribute '" << field << "' of type '" << typeString(attr)
            << "' is not usable in a script method "
            << "(did you forget to add it to __constants__?)";
      }
    }
    throw ErrorReport(loc) << "module has no attribute '" << field << "'";
  }
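
  // A sketch (not from the original source) of the _constants_set check
  // above, from the Python side; MyModule and scale are hypothetical:
  //
  //   class MyModule(torch.jit.ScriptModule):
  //       __constants__ = ['scale']  # populates _constants_set
  //       def __init__(self):
  //           super(MyModule, self).__init__()
  //           self.scale = 2.0  # usable from script methods as a constant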

  // call module.forward
  virtual std::shared_ptr<SugaredValue> call(
      SourceRange loc,
      Method& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    return attr(loc, caller, "forward")
        ->call(loc, caller, inputs, attributes, n_binders);
  }

  virtual std::vector<std::shared_ptr<SugaredValue>> asTuple(
      SourceRange loc,
      Method& m,
      c10::optional<size_t> size_hint = {}) override {
    py::object py_module = py::cast(module);
    if (!py::isinstance(
            py_module,
            py::module::import("torch.jit").attr("_ConstModuleList")))
      return SugaredValue::asTuple(loc, m, size_hint);
    std::vector<std::shared_ptr<SugaredValue>> result;
    for (py::handle module : py_module) {
      py::object obj = py::reinterpret_borrow<py::object>(module);
      result.push_back(toSugaredValue(
          obj,
          m,
          loc,
          /*is_constant=*/false,
          /*is_submodule=*/true));
    }
    return result;
  }

 private:
  std::shared_ptr<Module> module;
};

struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
  BooleanDispatchValue(py::dict dispatched_fn)
      : dispatched_fn_(std::move(dispatched_fn)) {}

  std::string kind() const override {
    return "boolean dispatch";
  }

  std::vector<NamedValue> removeIndex(
      at::ArrayRef<NamedValue> arr,
      size_t index) {
    auto sliced = arr.vec();
    sliced.erase(sliced.begin() + index);
    return sliced;
  }

  std::shared_ptr<SugaredValue> call(
      SourceRange loc,
      Method& caller,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes,
      size_t n_binders) override {
    c10::optional<bool> result;
    Graph& graph = *(caller.graph());

    auto index = py::cast<size_t>(dispatched_fn_["index"]);
    auto arg_name = py::str(dispatched_fn_["arg_name"]);

    if (index < inputs.size()) {
      // Dispatch flag is in arg list
      result = constant_as<bool>(inputs.at(index).value(graph));
    } else if (auto i = findInputWithName(arg_name, attributes)) {
      // Dispatch flag is in kwargs
      result = constant_as<bool>(attributes[*i].value(graph));
    } else {
      // Didn't find dispatch flag, so use default value
      result = py::cast<bool>(dispatched_fn_["default"]);
    }

    if (!result) {
      throw ErrorReport(loc) << "value for boolean dispatch was not constant";
    }

    std::shared_ptr<SugaredValue> value;
    if (*result) {
      value = toSugaredValue(dispatched_fn_["if_true"], caller, loc);
    } else {
      value = toSugaredValue(dispatched_fn_["if_false"], caller, loc);
    }
    return value->call(loc, caller, inputs, attributes, n_binders);
  }

 private:
  py::dict dispatched_fn_;
};
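
// A sketch (not from the original source) of the dict BooleanDispatchValue
// consumes; the keys mirror the lookups in call() above, while the example
// values are hypothetical:
//
//   dispatched_fn = {
//       'index': 2,                  # position of the boolean flag
//       'arg_name': 'return_indices',
//       'default': False,            # used when the flag is not passed
//       'if_true': fn_with_indices,  # chosen when the flag is constant True
//       'if_false': fn_plain,        # chosen when the flag is constant False
//   }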

std::shared_ptr<SugaredValue> toSugaredValue(
    py::object obj,
    Method& m,
    SourceRange loc,
    bool is_constant,
    bool is_submodule) {
  // directly create SimpleValues when possible, because they are first-class
  // and can be re-assigned. Otherwise, this would be invalid:
  //   f = python_constant
  //   while ...
  //     f = f + 1
  auto& g = *m.graph();
  if (is_constant) {
    if (py::isinstance<py::bool_>(obj)) {
      return toSimple(g.insertConstant(py::cast<bool>(obj), loc));
    } else if (py::isinstance<py::int_>(obj)) {
      return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc));
    } else if (py::isinstance<py::float_>(obj)) {
      return toSimple(g.insertConstant(py::cast<float>(obj), loc));
    } else if (THPDevice_Check(obj.ptr())) {
      auto device = (THPDevice*)obj.ptr();
      std::vector<int64_t> v = {static_cast<int64_t>(device->device.type()),
                                device->device.index()};
      return toSimple(g.insertConstant(std::move(v)));
    } else if (THPLayout_Check(obj.ptr())) {
      auto layout = (THPLayout*)obj.ptr();
      const auto v = static_cast<int64_t>(layout->layout);
      return toSimple(g.insertConstant(v, loc));
    } else if (THPDtype_Check(obj.ptr())) {
      auto dtype = (THPDtype*)(obj.ptr());
      const auto v = static_cast<int64_t>(dtype->scalar_type);
      return toSimple(g.insertConstant(v, loc));
    } else if (py::isinstance<py::tuple>(obj)) {
      return std::make_shared<ConstantPythonTupleValue>(obj);
    }
  }

  auto weak_obj =
      py::module::import("torch.jit").attr("_try_get_weak_module")(obj);
  if (!weak_obj.is_none()) {
    obj = weak_obj;
  }
  if (py::isinstance<Module>(obj)) {
    auto mod = py::cast<std::shared_ptr<Module>>(obj);
    // In the case that this Python object is not a submodule, inline *ONLY
    // PURE* ScriptModules. This allows us to call arbitrary @script functions
    // within a scripting context while still enforcing that parameters from
    // stateful submodules are properly accounted for.
    if (!is_submodule && mod->get_parameters().size() != 0) {
      throw ErrorReport()
          << "Attempted to inline a Module with parameters. "
             "Stateful modules to be inlined must be submodules of the callee.";
    }
    return std::make_shared<ModuleValue>(mod);
  } else if (py::isinstance<py::module>(obj)) {
    return std::make_shared<PythonModuleValue>(obj);
  } else if (obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr()) {
    return std::make_shared<ForkValue>();
  } else if (obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) {
    return std::make_shared<AnnotateValue>();
  }

  py::object builtin_name =
      py::module::import("torch.jit").attr("_find_builtin")(obj);
  if (!builtin_name.is_none()) {
    return std::make_shared<BuiltinFunction>(
        Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
  }

  if (py::isinstance<py::function>(obj)) {
    auto compiled_fn =
        py::module::import("torch.jit").attr("_try_compile_weak_script")(obj);
    if (!compiled_fn.is(py::none())) {
      auto mod = py::cast<std::shared_ptr<Module>>(compiled_fn);
      return std::make_shared<ModuleValue>(mod);
    }
  }

  py::object dispatched_fn =
      py::module::import("torch.jit").attr("_try_get_dispatched_fn")(obj);
  if (!dispatched_fn.is_none()) {
    return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
  }
  return std::make_shared<PythonValue>(obj);
}

py::object unpackVariableTensorList(std::vector<at::Tensor> outputs) {
  // if we don't tell pybind these are variables it chokes on the
  // conversion.
  // TODO: fix conversions to be sane and make sure this works.
  if (outputs.size() == 0) {
    return py::none();
  } else if (outputs.size() == 1) {
    return py::cast(autograd::as_variable_ref(outputs[0]));
  } else {
    py::tuple tuple(outputs.size());
    for (size_t i = 0; i < outputs.size(); i++) {
      tuple[i] = py::cast(autograd::as_variable_ref(outputs[i]));
    }
    return tuple;
  }
}
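
// A sketch (not from the original source) of what a Python caller receives
// from unpackVariableTensorList:
//
//   []         -> None
//   [t]        -> t          (a single Tensor)
//   [t1, t2]   -> (t1, t2)   (a tuple of Tensors)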

static void gatherParametersAndBuffers(
    std::vector<at::Tensor*>& values,
    const Module& m) {
  for (auto& param : m.get_parameters()) {
    values.push_back(param->slot());
  }
  for (const auto& sub : m.get_modules()) {
    gatherParametersAndBuffers(values, *sub->module);
  }
}

namespace {

Resolver pythonResolver(ResolutionCallback rcb) {
  return [rcb](const std::string& name, Method& m, const SourceRange& loc)
             -> std::shared_ptr<SugaredValue> {
    AutoGIL ag;
    py::object obj = rcb(name);
    if (obj.is(py::none())) {
      return nullptr;
    }
    return toSugaredValue(obj, m, loc);
  };
}
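
// A sketch (not from the original source) of a ResolutionCallback built on
// the Python side; the frame-capture details here are hypothetical:
//
//   def make_rcb(frame_locals, frame_globals):
//       def rcb(name):
//           # return the Python object `name` refers to, or None if unbound
//           return frame_locals.get(name, frame_globals.get(name))
//       return rcb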

} // namespace

FunctionSchema getSchemaWithNameAndDefaults(
    const SourceRange& range,
    const FunctionSchema& schema,
    at::optional<std::string> new_name,
    const FunctionDefaults& default_args) {
  std::vector<Argument> new_args;
  for (auto& arg : schema.arguments()) {
    auto it = default_args.find(arg.name());
    if (it != default_args.end()) {
      try {
        IValue value = toIValue(it->second, arg.type());
        new_args.push_back(
            Argument(arg.name(), arg.type(), arg.N(), value, arg.kwarg_only()));
      } catch (py::cast_error& e) {
        throw ErrorReport(range)
            << "Expected a default value of type " << arg.type()->str()
            << " on parameter \"" << arg.name() << "\"";
      }
    } else {
      new_args.push_back(arg);
    }
  }

  return FunctionSchema(
      new_name.value_or(schema.name()),
      new_args,
      schema.returns(),
      schema.is_vararg(),
      schema.is_varret());
}
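
// A sketch (not from the original source) of the defaults stitching above;
// `f` and its annotation are hypothetical:
//
//   def f(x, scale=2.0):
//       # type: (Tensor, float) -> Tensor
//       return x * scale
//
// arrives with default_args = {'scale': 2.0}, and the matching argument in
// the resulting schema carries the default, roughly:
//   f(Tensor x, float scale=2.0) -> Tensor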

void initJitScriptBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // torch.jit.ScriptModule is a subclass of this C++ object.
  // Methods here are prefixed with _ since they should not be
  // public.
  py::class_<Module, std::shared_ptr<Module>>(m, "ScriptModule")
      .def(py::init<>())
      .def("save", [](std::shared_ptr<Module> m, const std::string& filename) {
        m->save(filename);
      })
      .def("save_to_buffer", [](std::shared_ptr<Module> m) {
        std::ostringstream buf;
        m->save(buf);
        return py::bytes(buf.str());
      })
      .def("_set_optimized", &Module::set_optimized)
      .def(
          "_define",
          [](std::shared_ptr<Module> m,
             const std::string& script,
             ResolutionCallback rcb,
             bool has_self) {
            auto self = has_self ? std::make_shared<ModuleValue>(m) : nullptr;
            return defineMethodsInModule(m, script, pythonResolver(rcb), self);
          })
      .def("_create_methods", [](std::shared_ptr<Module> m,
                                 const std::vector<Def>& defs,
                                 const std::vector<ResolutionCallback>& rcbs,
                                 const std::vector<FunctionDefaults>& defaults) {
        std::vector<Resolver> resolvers;
        for (auto& callback : rcbs) {
          resolvers.push_back(pythonResolver(callback));
        }
        defineMethodsInModule(
            m,
            defs,
            resolvers,
            std::make_shared<ModuleValue>(m));

        // Stitch in default arguments for each Def if provided
        auto defaults_it = defaults.begin();
        auto defs_it = defs.begin();
        while (defs_it != defs.end()) {
          auto& method = m->get_method((*defs_it).name().name());
          method.setSchema(getSchemaWithNameAndDefaults(
              defs_it->range(), method.getSchema(), at::nullopt, *defaults_it));
          ++defs_it;
          ++defaults_it;
        }
        didFinishEmitModule(m);
      })
      .def("_get_method",
           [](Module& self, const std::string& name) -> const Method& {
             return self.get_method(name);
           },
           py::return_value_policy::reference_internal)
      .def("_register_parameter", &Module::register_parameter)
      .def("_register_module", &Module::register_module)
      .def("_set_parameter", &Module::set_parameter)
      .def("_get_parameter", &Module::get_parameter)
      .def("_get_module", &Module::get_module)
      .def("_get_modules", [](Module& self) -> py::tuple {
        auto& modules = self.get_modules();
        py::tuple result(modules.size());
        for (size_t i = 0; i < modules.size(); ++i) {
          auto& item = modules[i];
          result[i] = std::make_pair(item.key(), item.value());
        }
        return result;
      })
      .def("_get_parameters", [](Module& self) -> py::tuple {
        auto& parameters = self.get_parameters();
        py::tuple result(parameters.size());
        for (size_t i = 0; i < parameters.size(); ++i) {
          auto& p = parameters[i];
          result[i] = std::make_tuple(
              p.key(),
              autograd::as_variable_ref(*p->slot()),
              p->is_buffer);
        }
        return result;
      })
      .def("_has_parameter", [](Module& self, const std::string& name) {
        if (auto r = self.find_parameter(name)) {
          return !r->is_buffer;
        }
        return false;
      })
      .def("_has_buffer", [](Module& self, const std::string& name) {
        if (auto r = self.find_parameter(name)) {
          return r->is_buffer;
        }
        return false;
      })
      .def("_has_module", [](Module& self, const std::string& name) {
        return bool(self.find_module(name));
      })
      .def("_has_method", [](Module& self, const std::string& name) {
        return bool(self.find_method(name));
      })
      .def("_method_names", [](Module& self) {
        using Item =
            torch::OrderedDict<std::string, std::unique_ptr<Method>>::Item;
        return fmap(self.get_methods(), [](const Item& item) {
          return (*item)->name();
        });
      })
      .def("_create_method_from_graph", [](
        Module& self,
        const std::string& name,
        std::shared_ptr<Graph> graph
      ) {
        self.create_method(name, std::move(graph), {});
      })
      .def("_create_method_from_trace", [](
        std::shared_ptr<Module> self,
        const std::string& name,
        py::function func,
        py::tuple input_tuple,
        py::function var_lookup_fn) {
        // prereq: Module's buffers and parameters are unique;
        // this was ensured in Python before calling this function
        std::vector<at::Tensor*> parameters;
        gatherParametersAndBuffers(parameters, *self);
        Stack inputs = toStack(input_tuple);
        for (at::Tensor* param : parameters) {
          inputs.emplace_back(*param);
        }
        auto graph = tracer::createGraphByTracing(
            func, inputs, var_lookup_fn, input_tuple.size());
        self->create_method(name, std::move(graph), std::move(parameters));
        didFinishEmitModule(self);
      })
      .def("graph_for", [](py::args args, py::kwargs kwargs) {
        // [pybind11 varargs] note: old versions of pybind11 have a bug that
        // leaks memory when py::args is mixed with positional arguments
        // https://github.com/pybind/pybind11/pull/1216
        // we work around this by not mixing positional arguments with varargs
        Module& self = py::cast<Module&>(args[0]);
        if (self.find_method("forward")) {
          Method& m = self.get_method("forward");
          return m.graph_for(createStackForSchema(
              m.getSchema(), tuple_slice(std::move(args), 1), std::move(kwargs)));
        }
        throw std::runtime_error(
            "Attempted to call graph_for on a Module without a compiled forward()");
      })
      .def("get_debug_state", [](Module& self) {
        if (self.find_method("forward")) {
          Method& m = self.get_method("forward");
          return m.getDebugState();
        }
        throw std::runtime_error(
            "Attempted to call get_debug_state on a Module without a compiled forward()");
      })
      .def("debug_disable_autodiff_subgraph_inlining", [](Module& self) {
        if (self.find_method("forward")) {
          Method& m = self.get_method("forward");
          m.debugDisableAutodiffSubgraphInlining();
        }
      })
      .def("forward", [](py::args args, py::kwargs kwargs) {
        // We implement this in C++ to avoid incurring the pybind11 dispatch
        // overhead twice: once to call into the method lookup for "forward"
        // and once to actually invoke the method.
        //
        // There is a thin wrapper on top of this method in the C++ version of
        // ScriptModule.

        // see: [pybind11 varargs]
        Module& self = py::cast<Module&>(args[0]);
        return invokeScriptMethodFromPython(
            self.get_method("forward"),
            tuple_slice(std::move(args), 1),
            std::move(kwargs));
      });

  py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
      .def("graph", [](Method& self) {
        return self.graph();
      })
      .def("__call__", [](py::args args, py::kwargs kwargs) {
        // see: [pybind11 varargs]
        Method& method = py::cast<Method&>(args[0]);
        return invokeScriptMethodFromPython(
            method, tuple_slice(std::move(args), 1), std::move(kwargs));
      })
      .def_property_readonly("graph", [](Method& m) {
        return m.graph();
      })
      .def("propagate_shapes", &Method::propagate_shapes)
      .def("propagate_and_assign_input_and_output_shapes",
           &Method::propagate_and_assign_input_and_output_shapes)
      .def("params", &Method::params)
      .def("graph_for", [](py::args args, py::kwargs kwargs) {
        // see: [pybind11 varargs]
        Method& self = py::cast<Method&>(args[0]);
        return self.graph_for(createStackForSchema(
            self.getSchema(), tuple_slice(std::move(args), 1), std::move(kwargs)));
      })
      .def("debug_disable_autodiff_subgraph_inlining",
           &Method::debugDisableAutodiffSubgraphInlining)
      .def("pretty_print_schema", &Method::pretty_print_schema)
      .def("python_print", [](Method& m) {
        std::ostringstream oss;
        std::vector<at::Tensor> constants = PythonPrint(oss, m, true);
        return std::make_pair(oss.str(), std::move(constants));
      });

  m.def(
      "_jit_script_compile",
      [](std::shared_ptr<Module> mod,
         const Def& def,
         ResolutionCallback rcb,
         FunctionDefaults defaults) {
        auto def_f = def.withName("forward");
        defineMethodsInModule(mod, {def_f}, {pythonResolver(rcb)}, nullptr);
        auto& method = mod->get_method("forward");
        method.setSchema(getSchemaWithNameAndDefaults(
            def.range(), method.getSchema(), def.name().name(), defaults));
        didFinishEmitModule(mod);
        return mod;
      });
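
  // A sketch (not from the original source) of how _jit_script_compile is
  // reached from Python; `f` is hypothetical. The `defaults` map carries the
  // Python default values into the compiled method's schema (see
  // getSchemaWithNameAndDefaults above):
  //
  //   @torch.jit.script
  //   def f(x, y=1):
  //       return x + y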

  m.def("parse_type_comment", [](const std::string& comment) {
    Parser p(comment);
    return Decl(p.parseTypeComment(true));
  });
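
  // A sketch (not from the original source) of the input parse_type_comment
  // expects: a mypy-style type comment string, e.g.
  //
  //   # type: (Tensor, int) -> Tensor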

  m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
  m.def("import_ir_module",
        [](ModuleLookup module_lookup, const std::string& filename) {
          import_ir_module(module_lookup, filename);
        });
  m.def("import_ir_module_from_buffer",
        [](ModuleLookup module_lookup, const std::string& buffer) {
          std::istringstream in(buffer);
          import_ir_module(module_lookup, in);
        });
  m.def("_jit_import_method", import_method);
  m.def("_jit_set_emit_module_hook", setEmitModuleHook);
}

} // namespace script
} // namespace jit
} // namespace torch