Factor Module into Object and Module

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29500

Test Plan: Imported from OSS

Differential Revision: D18463064

Pulled By: jamesr66a

fbshipit-source-id: d37bef242a8626593d4b8754042152cfc0f0acb2
James Reed, 2019-11-17 22:56:49 -08:00 (committed by Facebook Github Bot)
parent 14946a8891
commit 18bdf97dbb
24 changed files with 368 additions and 301 deletions
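
The sections below are the raw diff. As a minimal sketch of what the refactor means for C++ callers (not part of the diff; header paths, the class name, and the tensor values are illustrative), attribute and method access now lives on script::Object, and script::Module inherits it:

#include <torch/csrc/jit/script/module.h>
#include <torch/torch.h>

using namespace torch::jit::script;

int main() {
  // Module still owns parameters, buffers, and submodules...
  Module m(c10::QualifiedName("MyModule"));
  m.register_parameter("weight", torch::ones({2, 2}), /*is_buffer=*/false);
  m.define(R"JIT(
def forward(self, x):
    return self.weight + x
)JIT");

  // ...while attribute and method access is inherited from Object.
  c10::IValue w = m.attr("weight");
  bool has_forward = m.find_method("forward").has_value();
  c10::IValue out = m.run_method("forward", torch::zeros({2, 2}));
  (void)w; (void)has_forward; (void)out;
  return 0;
}

The same attr/setattr/run_method calls keep working on Module because they are inherited rather than re-declared, which is the point of the split.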


@@ -44,7 +44,7 @@ void dump_opnames(const script::Module& m, std::unordered_set<std::string>& opna
     }
   }
   for (const auto& sub_m : m.children()) {
-    std::cout << "sub module name: " << sub_m.name().qualifiedName() << std::endl;
+    std::cout << "sub module name: " << sub_m.type()->name()->qualifiedName() << std::endl;
     dump_opnames(sub_m, opnames);
   }
 }


@@ -444,6 +444,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
     ${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
     ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/script/object.cpp
     ${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
     ${TORCH_SRC_DIR}/csrc/jit/source_range_serialization.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp


@@ -89,12 +89,12 @@ void testScriptObject() {
   Module m2("m2");
   std::vector<at::Tensor> constantTable;
   import_libs(
-      m1.class_compilation_unit(),
+      m1._ivalue()->compilation_unit(),
       "__torch__.FooTest",
       std::make_shared<Source>(classSrcs1),
       constantTable);
   import_libs(
-      m2.class_compilation_unit(),
+      m2._ivalue()->compilation_unit(),
       "__torch__.FooTest",
       std::make_shared<Source>(classSrcs2),
       constantTable);


@@ -64,7 +64,7 @@ void testModuleInterfaceSerialization() {
   parentMod.register_attribute(
       "subMod",
       cu->get_interface("__torch__.OneForward"),
-      subMod.module_object(),
+      subMod._ivalue(),
       /*is_parameter=*/false);
   parentMod.define(parentForward, nativeResolver());
   ASSERT_TRUE(parentMod.hasattr("subMod"));


@@ -164,6 +164,7 @@ libtorch_sources = [
     "torch/csrc/jit/hooks_for_testing.cpp",
     "torch/csrc/jit/script/builtin_functions.cpp",
     "torch/csrc/jit/script/module.cpp",
+    "torch/csrc/jit/script/object.cpp",
     "torch/csrc/jit/tracer.cpp",
     "torch/csrc/jit/fuser/kernel_cache.cpp",
     "torch/csrc/jit/fuser/compiler.cpp",


@@ -26,7 +26,7 @@ namespace serialize {
 class TORCH_API OutputArchive final {
  public:
   explicit OutputArchive(std::shared_ptr<jit::script::CompilationUnit> cu);
-  explicit OutputArchive() : cu_(std::make_shared<jit::script::CompilationUnit>()) {}
+  explicit OutputArchive() : cu_(std::make_shared<jit::script::CompilationUnit>()), module_("__torch__.Module", cu_) {}

   // Move is allowed.
   OutputArchive(OutputArchive&&) = default;


@@ -16,7 +16,7 @@
 namespace torch {
 namespace serialize {

-InputArchive::InputArchive() {}
+InputArchive::InputArchive() : module_("Module", std::make_shared<jit::script::CompilationUnit>()) {}

 void InputArchive::read(const std::string& key, c10::IValue& ivalue) {
   ivalue = module_.attr(key);


@@ -546,7 +546,7 @@ class ScriptModuleSerializer {
     C10_LOG_API_USAGE_ONCE("torch.script.save");
     writeExtraFiles(module, extra_files);
     // Serialize the model object
-    writeArchive("data", module.module_object());
+    writeArchive("data", module._ivalue());
     // Then we werialize all code info.
     writeCode(module.type());
     // The tensor constants from the code are written to a separate archive


@@ -257,9 +257,9 @@ void ScriptModuleDeserializer::LEGACY_moduleSetState(
   // TODO: once modules are first class in the interpreter and methods are not
   // lowered, change this to `module->run_method("__setstate__", {state});`
   if (setstate->num_inputs() == 1) {
-    setstate->run({module.module_object()});
+    setstate->run({module._ivalue()});
   } else if (setstate->num_inputs() == 2) {
-    setstate->run({module.module_object(), state});
+    setstate->run({module._ivalue(), state});
   } else {
     AT_ERROR("Unexpected schema on '__setstate__'");
   }

@@ -348,11 +348,11 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule(
         LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id()));
   }
-  const ClassTypePtr& module_type = module.module_object()->type();
+  const ClassTypePtr& module_type = module._ivalue()->type();
   for (size_t i = 0, N = module_type->numAttributes(); i < N; ++i) {
     // Verify that all the non-optional attributes have been initialized
     // TODO: Issue #20497
-    const IValue& v = module.module_object()->getSlot(i);
+    const IValue& v = module._ivalue()->getSlot(i);
     if (module_type->getAttribute(i)->kind() != TypeKind::OptionalType) {
       TORCH_CHECK(
           !v.isNone(),


@@ -193,7 +193,7 @@ struct SourceImporterImpl : public Resolver,
       const script::Module& mod,
       const std::shared_ptr<Source>& src) {
     auto self = SimpleSelf(mod.type());
-    c10::QualifiedName prefix = mod.name();
+    c10::QualifiedName prefix = *mod.type()->name();
     Parser p(src);
     parsePossibleVersionNumber(p.lexer());


@@ -267,7 +267,7 @@ void initJITBindings(PyObject* module) {
       .def(
           "_jit_pass_lower_graph",
           [](std::shared_ptr<Graph>& graph, const script::Module& self) {
-            return LowerGraph(*graph, self.module_object());
+            return LowerGraph(*graph, self._ivalue());
           })
       .def("_jit_pass_loop_unrolling", UnrollLoops)
       .def(


@@ -30,7 +30,7 @@ void fillQConfigMap(
   } else {
     qconfig = parent_qconfig;
   }
-  map[module.module_object()] = qconfig;
+  map[module._ivalue()] = qconfig;
   for (const script::NameModule& s : module.named_children()) {
     std::string child_key;

@@ -39,8 +39,7 @@ void fillQConfigMap(
     } else {
       child_key = key + "." + s.name;
     }
-    fillQConfigMap(
-        s.value.module_object(), qconfig_dict, map, child_key, qconfig);
+    fillQConfigMap(s.value._ivalue(), qconfig_dict, map, child_key, qconfig);
   }
 }

@@ -184,7 +183,8 @@ Node* InsertObserversHelper::insertObserverFor(
     return nullptr;
   }
-  script::Module observer_module;
+  script::Module observer_module(
+      "Module", std::make_shared<script::CompilationUnit>());
   if (isWeightOfConvOrLinear(v)) {
     TORCH_CHECK(v->uses().size() == 1, "We only support weight being used by one node.");
     observer_module = std::get<1>(qconfig);

@@ -275,7 +275,7 @@ graph(%input, %weight, %bias, %4):
 void InsertObserversHelper::insertObservers(
     script::Module& module,
     const std::string& method_name) {
-  if (!module_qconfig_map_.count(module.module_object())) {
+  if (!module_qconfig_map_.count(module._ivalue())) {
     // the module is added by us, e.g.: observer module
     return;
   }

@@ -304,7 +304,7 @@ void InsertObserversHelper::insertObservers(
   for (size_t idx = 1; idx < method.num_inputs(); ++idx) {
     auto& v = graph->inputs()[idx];
     if (!values_to_skip_.count(v) && valueNeedsToBeQuantized(v)) {
-      auto qconfig = module_qconfig_map_.at(module.module_object());
+      auto qconfig = module_qconfig_map_.at(module._ivalue());
       if (qconfig) {
         auto observer_node =
             insertObserverFor(v, v->owningGraph(), module, qconfig.value());

@@ -339,7 +339,8 @@ void InsertObserversHelper::insertObservers(
       // the child module.
       auto module_instance = n->inputs()[0];
       auto module_method_name = n->s(attr::name);
-      script::Module callee_module;
+      script::Module callee_module(
+          "Module", std::make_shared<script::CompilationUnit>());
       if (module_instance->node()->kind() == prim::GetAttr) {
         auto child_module_name = module_instance->node()->s(attr::name);
         callee_module = module.attr(child_module_name).toModule();

@@ -365,7 +366,7 @@ void InsertObserversHelper::insertObservers(
   // Actually add observer nodes.
   for (Value* v : values_to_observe) {
-    auto qconfig = module_qconfig_map_.at(module.module_object());
+    auto qconfig = module_qconfig_map_.at(module._ivalue());
     // Skip inserting observer if no qconfig is specified
     if (qconfig) {
       insertObserverFor(v, v->owningGraph(), module, qconfig.value());

@@ -455,7 +456,7 @@ class QuantizeHelper {
     // O(N) where N is number of observer moduels with this optimization
     for (int64_t i = observer_modules_to_remove_.size() - 1; i >= 0; --i) {
       auto observer_name = observer_modules_to_remove_[i];
-      module_.module_object()->unsafeRemoveAttr(observer_name);
+      module_._ivalue()->unsafeRemoveAttr(observer_name);
       module_.type()->unsafeRemoveAttribute(observer_name);
     }
     // Destroy observer forward calls

@@ -826,7 +827,8 @@ graph(%self, %x):
   script::Method method = current.get_method("forward");
   GRAPH_DUMP(
-      current.name().name() + "::forward() before Conv2d-BatchNorm2d folding",
+      current.type()->name()->name() +
+          "::forward() before Conv2d-BatchNorm2d folding",
       method.graph());
   const auto& matches = findPatternMatches(pattern_graph, *method.graph());


@@ -458,7 +458,7 @@ inline IValue toIValue(
       auto classType = type->expect<ClassType>();
       if (auto mod = script::as_module(py::cast<py::object>(obj))) {
         // if obj is already a ScriptModule, just return its ivalue
-        return mod.value().module_object();
+        return mod.value()._ivalue();
       }
       // otherwise is a normal class object, we create a fresh
       // ivalue::Object to use from the py object.

@@ -487,7 +487,7 @@ inline IValue toIValue(
       IValue res;
       if (auto mod = script::as_module(py::cast<py::object>(obj))) {
         classType = mod.value().type();
-        res = mod.value().module_object();
+        res = mod.value()._ivalue();
       } else {
         // We inspect the value to found the compiled TorchScript class
         // and then create a ivalue::Object from that class type.

@@ -926,7 +926,7 @@ inline py::object invokeScriptMethodFromPython(
     script::Method& callee,
     tuple_slice args,
     py::kwargs kwargs) {
-  auto self = callee.owner().module_object();
+  auto self = callee.owner()._ivalue();
   return runAndInsertCall(
       callee.function(),
       args,


@@ -725,7 +725,8 @@ void initPythonIRBindings(PyObject* module_) {
   py::class_<ClassType, Type, std::shared_ptr<ClassType>>(m, "ClassType")
       .def(py::init([](const std::string& qualified_name) {
         return get_python_cu()->get_class(c10::QualifiedName(qualified_name));
-      }));
+      }))
+      .def("name", [](ClassType& self) { return self.name()->name(); });
   py::class_<InterfaceType, Type, std::shared_ptr<InterfaceType>>(
       m, "InterfaceType")
       .def(py::init([](const std::string& qualified_name) {


@@ -390,9 +390,10 @@ void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
   // Make a graph with a fake self argument
   auto graph = func.function_->graph()->copy();
   auto v = graph->insertInput(0, "self");
-  v->setType(module.module_object()->type());
-  const auto name = QualifiedName(module.name(), "forward");
-  auto method = module.class_compilation_unit()->create_function(name, graph);
+  v->setType(module._ivalue()->type());
+  const auto name = QualifiedName(*module.type()->name(), "forward");
+  auto method =
+      module._ivalue()->compilation_unit()->create_function(name, graph);
   module.type()->addMethod(method);
 }

@@ -403,7 +404,7 @@ bool ivalue_tags_match(const Module& lhs, const Module& rhs) {
     IValue b;
   };
   std::unordered_set<const void*> visited;
-  std::vector<Work> work = {{lhs.module_object(), rhs.module_object()}};
+  std::vector<Work> work = {{lhs._ivalue(), rhs._ivalue()}};
   while (!work.empty()) {
     Work item = work.back();
     work.pop_back();

@@ -501,7 +502,7 @@ struct slot_dict_impl {
   static void bind(const py::module& m, const char* name) {
     py::class_<slot_dict_impl<Policy>>(m, name)
         .def(py::init(
-            [](Module& m) { return slot_dict_impl<Policy>(m.module_object()); }))
+            [](Module& m) { return slot_dict_impl<Policy>(m._ivalue()); }))
        .def("contains", &slot_dict_impl<Policy>::contains)
        .def("items", &slot_dict_impl<Policy>::items)
        .def("setattr", &slot_dict_impl<Policy>::setattr)

@@ -562,10 +563,56 @@ void initJitScriptBindings(PyObject* module) {
   // follows.
   py::bind_map<ExtraFilesMap>(m, "ExtraFilesMap");

+  py::class_<Object>(m, "ScriptObject")
+      .def("_type", [](Module& m) { return m.type(); })
+      .def(
+          "_get_method",
+          [](Object& self, const std::string& name) -> Method {
+            return self.get_method(name);
+          },
+          py::keep_alive<0, 1>())
+      .def(
+          "setattr",
+          [](Object& self, const std::string& name, py::object value) {
+            TypePtr type = self.type()->getAttribute(name);
+            TORCH_CHECK(type, "Module has no attribute '", name, "'");
+            auto ivalue = toIValue(std::move(value), type);
+            self.setattr(name, ivalue);
+          })
+      .def(
+          "getattr",
+          [](Object& self, const std::string& name) {
+            return toPyObject(self.attr(name));
+          })
+      .def(
+          "__getattr__",
+          [](Object& self, const std::string& name) {
+            if (auto method = self.find_method(name)) {
+              return py::cast(*method);
+            }
+            return toPyObject(self.attr(name));
+          })
+      .def(
+          "hasattr",
+          [](Object& self, const std::string& name) {
+            return self.hasattr(name);
+          })
+      .def(
+          "_has_method",
+          [](Object& self, const std::string& name) {
+            return bool(self.find_method(name));
+          })
+      .def(
+          "_method_names", [](Object& self) {
+            return fmap(self.get_methods(), [](const Method& method) {
+              return method.name();
+            });
+          });
+
   // torch.jit.ScriptModule is a subclass of this C++ object.
   // Methods here are prefixed with _ since they should not be
   // public.
-  py::class_<Module>(m, "ScriptModule")
+  py::class_<Module, Object>(m, "ScriptModule")
       .def(py::init<std::string, std::shared_ptr<CompilationUnit>, bool>())
       .def(
           "save",

@@ -598,46 +645,10 @@ void initJitScriptBindings(PyObject* module) {
           py::arg("attrs") = true,
           py::arg("params") = true,
           py::arg("indent") = 0)
-      .def(
-          "_define",
-          [](Module& m,
-             std::shared_ptr<ConcreteModuleType> concreteType,
-             const std::string& script,
-             ResolutionCallback rcb) {
-            const auto self = ModuleSelf(std::move(concreteType));
-            m.class_compilation_unit()->define(
-                m.name(), script, pythonResolver(rcb), &self);
-            didFinishEmitModule(m);
-          })
-      .def("_type", [](Module& m) { return m.type(); })
-      .def(
-          "_get_method",
-          [](Module& self, const std::string& name) -> Method {
-            return self.get_method(name);
-          },
-          py::keep_alive<0, 1>())
-      .def(
-          "setattr",
-          [](Module& self, const std::string& name, py::object value) {
-            TypePtr type = self.type()->getAttribute(name);
-            TORCH_CHECK(type, "Module has no attribute '", name, "'");
-            auto ivalue = toIValue(std::move(value), type);
-            self.setattr(name, ivalue);
-          })
-      .def(
-          "getattr",
-          [](Module& self, const std::string& name) {
-            return toPyObject(self.attr(name));
-          })
-      .def(
-          "hasattr",
-          [](Module& self, const std::string& name) {
-            return self.hasattr(name);
-          })
       .def(
           "_replicate_for_data_parallel",
           [](Module& module) {
-            const ModulePtr& obj = module.module_object();
+            const ModulePtr& obj = module._ivalue();
             auto copy = c10::ivalue::Object::create(
                 c10::StrongTypePtr(obj->compilation_unit(), obj->type()),
                 obj->slots().size());

@@ -647,16 +658,24 @@ void initJitScriptBindings(PyObject* module) {
            return Module(std::move(copy));
           })
       .def(
-          "_has_method",
-          [](Module& self, const std::string& name) {
-            return bool(self.find_method(name));
+          "get_debug_state",
+          [](Module& self) {
+            if (auto m = self.find_method("forward")) {
+              return m->get_executor().getDebugState();
+            }
+            throw std::runtime_error(
+                "Attempted to call get_debug_state on a Module without a compiled forward()");
           })
       .def(
-          "_method_names",
-          [](Module& self) {
-            return fmap(self.get_methods(), [](const Method& method) {
-              return method.name();
-            });
+          "_define",
+          [](Module& m,
+             std::shared_ptr<ConcreteModuleType> concreteType,
+             const std::string& script,
+             ResolutionCallback rcb) {
+            const auto self = ModuleSelf(std::move(concreteType));
+            m._ivalue()->compilation_unit()->define(
+                *m.type()->name(), script, pythonResolver(rcb), &self);
+            didFinishEmitModule(m);
           })
       .def(
           "_create_method_from_trace",

@@ -672,21 +691,12 @@ void initJitScriptBindings(PyObject* module) {
            std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing(
                func, typed_inputs, var_lookup_fn, force_outplace, &self));
-            const auto method_name = QualifiedName(self.name(), name);
-            auto fn = self.class_compilation_unit()->create_function(
+            const auto method_name = QualifiedName(*self.type()->name(), name);
+            auto fn = self._ivalue()->compilation_unit()->create_function(
                method_name, graph);
            self.type()->addMethod(fn);
            didFinishEmitModule(self);
          })
-      .def(
-          "get_debug_state",
-          [](Module& self) {
-            if (auto m = self.find_method("forward")) {
-              return m->get_executor().getDebugState();
-            }
-            throw std::runtime_error(
-                "Attempted to call get_debug_state on a Module without a compiled forward()");
-          })
       .def_property_readonly(
           "code",
           [](Module& self) {

@@ -697,9 +707,7 @@ void initJitScriptBindings(PyObject* module) {
            return pp.str();
          })
       .def("apply", &Module::apply)
-      .def("_clone", &Module::clone)
-      .def_property_readonly(
-          "name", [](const Module& self) { return self.name().name(); });
+      .def("_clone", &Module::clone);

   slot_dict_impl<script::detail::ParameterPolicy>::bind(m, "ParameterDict");
   slot_dict_impl<script::detail::BufferPolicy>::bind(m, "BufferDict");
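
The binding change above relies on a standard pybind11 pattern: register the base class first, then list it as the second template argument when binding the derived class, so the Python type inherits the base's methods instead of re-defining them. A self-contained sketch of just that pattern (the struct members and module name here are illustrative, not from the PR):

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Object {
  int slots = 0;
};
struct Module : Object {
  bool training = true;
};

PYBIND11_MODULE(example, m) {
  // Bind the base first; its defs become available on every subclass binding.
  py::class_<Object>(m, "ScriptObject")
      .def_readonly("slots", &Object::slots);
  // The second template argument declares Object as the base of Module.
  py::class_<Module, Object>(m, "ScriptModule")
      .def(py::init<>())
      .def_readwrite("training", &Module::training);
}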


@@ -0,0 +1,65 @@
#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <torch/csrc/jit/function.h>
namespace torch {
namespace jit {
namespace script {
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
// A method in a module, e.g. f in:
//
// class M(ScriptModule):
// @script_method
// def f(self, x):
// ...
// Note: because Method/Module are exposed to python these
// classes use python method naming conventions
struct TORCH_API Method {
Method(ObjectPtr owner, Function* function);
// the module that contains this method.
Module owner() const;
void run(Stack& stack);
void run(Stack&& stack) {
run(stack);
}
c10::IValue operator()(
std::vector<c10::IValue> stack,
const Kwargs& kwargs = Kwargs());
std::shared_ptr<Graph> graph() const {
return function_->graph();
}
const std::string& name() const {
return function_->name();
}
size_t num_inputs() const {
return function_->num_inputs();
}
GraphExecutor& get_executor() {
return function_->get_executor();
}
Function& function() const {
return *function_;
}
private:
// Methods are uniqued onwed by a single module. This raw pointer allows
// looking up the module.
ObjectPtr owner_;
// Underlying unbound function
Function* function_;
};
} // namespace script
} // namespace jit
} // namespace torch
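
A short usage sketch of the Method wrapper declared above (not part of the diff; it assumes a script::Module m that already has a forward(self, x) method). As module.cpp shows later in this diff, Method::run and operator() prepend the owning object's ivalue as self before dispatching to the underlying Function:

#include <torch/csrc/jit/script/module.h>
#include <torch/torch.h>

c10::IValue call_forward(torch::jit::script::Module& m) {
  torch::jit::script::Method forward = m.get_method("forward");
  std::vector<c10::IValue> stack{torch::randn({3})};
  // Equivalent to m.run_method("forward", torch::randn({3})).
  return forward(std::move(stack));
}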


@@ -14,7 +14,7 @@ namespace torch {
 namespace jit {
 namespace script {

-static ModulePtr create_module_object(
+static ObjectPtr create_module_object(
     c10::QualifiedName class_name,
     std::shared_ptr<CompilationUnit> cu,
     bool shouldMangle = false) {

@@ -33,14 +33,14 @@ static ModulePtr create_module_object(
 }

 Module::Module(c10::QualifiedName class_name)
-    : module_value_(create_module_object(
+    : Object(create_module_object(
           std::move(class_name),
           std::make_shared<CompilationUnit>())) {}

 Module::Module(
     std::shared_ptr<CompilationUnit> cu,
     const c10::ClassTypePtr& type)
-    : module_value_(c10::ivalue::Object::create(
+    : Object(c10::ivalue::Object::create(
           c10::StrongTypePtr(std::move(cu), type),
           type->numAttributes())) {}

@@ -48,21 +48,11 @@ Module::Module(
     c10::QualifiedName class_name,
     std::shared_ptr<CompilationUnit> cu,
     bool shouldMangle)
-    : module_value_(create_module_object(
+    : Object(create_module_object(
          std::move(class_name),
          std::move(cu),
          shouldMangle)) {}

-ModulePtr Module::module_object() const {
-  if (!module_value_) {
-    // User has created a Model without assigning it to something already
-    // loaded. This is done in tests, and when using the .define method.
-    module_value_ =
-        create_module_object("Module", std::make_shared<CompilationUnit>());
-  }
-  return module_value_;
-}
-
 // first class mode runs models as first class objects,
 // and does not force inlining everywhere. This is experimental
 // as we bring up the system since it will degrade performance

@@ -155,21 +145,15 @@ Module Method::owner() const {
   return Module(owner_);
 }

 void Method::run(Stack& stack) {
-  stack.insert(stack.begin(), owner().module_object());
+  stack.insert(stack.begin(), owner()._ivalue());
   function_->run(stack);
 }

 IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
-  stack.insert(stack.begin(), owner().module_object());
+  stack.insert(stack.begin(), owner()._ivalue());
   return (*function_)(std::move(stack), kwargs);
 }

-void Module::define(const std::string& src, const ResolverPtr& resolver) {
-  const auto self = SimpleSelf(type());
-  class_compilation_unit()->define(
-      name(), src, resolver ? resolver : script::nativeResolver(), &self);
-}
-
 void Module::clone_method(
     const Module& orig,
     const Function& method,

@@ -195,7 +179,7 @@ void Module::clone_method(
   auto schema = method.getSchema().cloneWithRemappedTypes(type_remap_fn);
   const auto this_method_name = getNameForMethod(method.name());
   auto copied =
-      class_compilation_unit()->create_function(this_method_name, graph);
+      _ivalue()->compilation_unit()->create_function(this_method_name, graph);
   type()->addMethod(copied);
   copied->setSchema(std::move(schema));
 }

@@ -206,8 +190,7 @@ void Module::clone_method(const Module& orig, const std::string& name) {
   while (!to_scan.empty()) {
     auto entry = to_scan.back();
     to_scan.pop_back();
-    type_remap[entry.first.module_object()->type()] =
-        entry.second.module_object()->type();
+    type_remap[entry.first._ivalue()->type()] = entry.second._ivalue()->type();
     for (const NameModule& s : entry.first.named_children()) {
       to_scan.emplace_back(
           s.value, Module(entry.second.attr(s.name).toObject()));

@@ -223,16 +206,16 @@ Module Module::clone() const {
 Module Module::clone_impl(
     std::unordered_map<TypePtr, TypePtr>& type_remap) const {
-  // Create a new module_object in the same compilation unit.
+  // Create a new _ivalue in the same compilation unit.
   // The name is the same as for the original module, but it'll be mangled.
   // The class type is also created from scratch.
-  Module r(name(), class_compilation_unit(), true);
+  Module r(*type()->name(), _ivalue()->compilation_unit(), true);
   type_remap[type()] = r.type();

   // Copy slots. If a slot is a module - recursively clone it.
   size_t N = type()->numAttributes();
   for (size_t i = 0; i < N; ++i) {
-    IValue s = module_object()->getSlot(i);
+    IValue s = _ivalue()->getSlot(i);
     if (type()->getAttribute(i)->is_module()) {
       const Module& orig = Module(s.toObject());
       Module cloned = orig.clone_impl(type_remap);

@@ -256,8 +239,8 @@ Module Module::clone_impl(
 void Module::train(bool on) {
   for (Module m : modules()) {
-    if (auto slot = m.module_object()->type()->findAttributeSlot("training")) {
-      m.module_object()->setSlot(*slot, on);
+    if (auto slot = m._ivalue()->type()->findAttributeSlot("training")) {
+      m._ivalue()->setSlot(*slot, on);
     } else {
       TORCH_INTERNAL_ASSERT("'training' attribute not found");
     }

@@ -267,7 +250,7 @@ void Module::train(bool on) {
 IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
   // Look up the class
   const auto classType =
-      class_compilation_unit()->get_class(c10::QualifiedName(name));
+      _ivalue()->compilation_unit()->get_class(c10::QualifiedName(name));
   if (!classType) {
     AT_ERROR(
         "Could not find class with name: '",

@@ -278,7 +261,7 @@ IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
   // Create a bare object with correct number of slots
   const size_t numAttrs = classType->numAttributes();
   auto obj = c10::ivalue::Object::create(
-      c10::StrongTypePtr(class_compilation_unit(), classType), numAttrs);
+      c10::StrongTypePtr(_ivalue()->compilation_unit(), classType), numAttrs);

   // Invoke the `__init__()` of the class with the arguments provided.
   Stack stackWithSelf = {obj};

@@ -319,15 +302,6 @@ named_parameter_list Module::named_parameters(bool recurse) const {
   return named_parameter_list(*this, recurse, /*return_module=*/false);
 }

-c10::optional<Method> Module::find_method(const std::string& basename) const {
-  for (Function* fn : type()->methods()) {
-    if (fn->name() == basename) {
-      return Method(module_object(), fn);
-    }
-  }
-  return c10::nullopt;
-}
-
 attribute_list Module::attributes(bool recurse) const {
   return attribute_list(*this, recurse, /*return_module=*/false);
 }

@@ -380,7 +354,7 @@ std::string Module::dump_to_str(
     methods_ss << " }" << std::endl;
   }

-  ss << "module " << name().qualifiedName() << " {" << std::endl;
+  ss << "module " << type()->name()->qualifiedName() << " {" << std::endl;
   ss << " parameters {" << std::endl;
   ss << torch::jit::jit_log_prefix(" ", parameters_ss.str());
   ss << " }" << std::endl;


@@ -6,6 +6,7 @@
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/named_value.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
+#include <torch/csrc/jit/script/object.h>
 #include <torch/csrc/jit/source_range.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>

@@ -84,73 +85,17 @@ using named_buffer_list =
 using ModuleLookup = std::function<Module(const std::vector<std::string>&)>;

-// A method in a module, e.g. f in:
-//
-// class M(ScriptModule):
-//   @script_method
-//   def f(self, x):
-//     ...
-// Note: because Method/Module are exposed to python these
-// classes use python method naming conventions
-struct TORCH_API Method {
-  Method(ModulePtr owner, Function* function);
-  // the module that contains this method.
-  Module owner() const;
-  void run(Stack& stack);
-  void run(Stack&& stack) {
-    run(stack);
-  }
-  IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs());
-  std::shared_ptr<Graph> graph() const {
-    return function_->graph();
-  }
-  const std::string& name() const {
-    return function_->name();
-  }
-  size_t num_inputs() const {
-    return function_->num_inputs();
-  }
-  GraphExecutor& get_executor() {
-    return function_->get_executor();
-  }
-  Function& function() const {
-    return *function_;
-  }
-
- private:
-  // Methods are uniqued onwed by a single module. This raw pointer allows
-  // looking up the module.
-  ModulePtr owner_;
-  // Underlying unbound function
-  // This is the _lowered_ function and is different than the
-  // first-class function in class_compilation_unit()
-  Function* function_;
-};
-
-struct TORCH_API Module {
+struct TORCH_API Module : public Object {
   explicit Module(c10::QualifiedName class_name);
   Module(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
+  Module() {}
   Module(
       c10::QualifiedName,
       std::shared_ptr<CompilationUnit> cu,
       bool shouldMangle = false);
-  // module_value_ null and will be lazily initialized if is needed
-  Module() {}
-  Module(ModulePtr module_value) : module_value_(std::move(module_value)) {}
+  Module(ModulePtr module_value) : Object(std::move(module_value)) {}
   ~Module() {}

-  const c10::QualifiedName& name() const {
-    return *module_object()->type()->name();
-  }
-
   void set_optimized(bool o) {
     AT_WARN(
         "Module::set_optimized() is deprecated and has no effect. "

@@ -174,14 +119,14 @@ struct TORCH_API Module {
   // whether a slot is a parameter to be able to classify it.
   void register_buffer(const std::string& name, at::Tensor v) {
     type()->addOrCheckAttribute(name, TensorType::get());
-    module_object()->setAttr(name, std::move(v));
+    _ivalue()->setAttr(name, std::move(v));
   }
   void register_parameter(
       const std::string& name,
       at::Tensor v,
       bool is_buffer) {
     type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer);
-    module_object()->setAttr(name, std::move(v));
+    _ivalue()->setAttr(name, std::move(v));
   }
   void register_attribute(
       const std::string& name,

@@ -189,51 +134,11 @@ struct TORCH_API Module {
       IValue v,
       bool is_param = false) {
     type()->addOrCheckAttribute(name, t, is_param);
-    module_object()->setAttr(name, std::move(v));
+    _ivalue()->setAttr(name, std::move(v));
   }
   void register_module(const std::string& name, const Module& module) {
     type()->addOrCheckAttribute(name, module.type());
-    module_object()->setAttr(name, module.module_object());
-  }
-
-  void setattr(const std::string& name, IValue v) {
-    size_t slot = module_object()->type()->getAttributeSlot(name);
-    const TypePtr& expected = module_object()->type()->getAttribute(slot);
-    TORCH_CHECK(expected, "Module has no attribute '", name, "'");
-    TORCH_CHECK(
-        v.type()->isSubtypeOf(expected),
-        "Expected a value of type '",
-        expected->python_str(),
-        "' for field '",
-        name,
-        "', but found '",
-        v.type()->python_str(),
-        "'");
-    module_object()->setSlot(slot, std::move(v));
-  }
-
-  IValue attr(const std::string& name) const {
-    return module_object()->getAttr(name);
-  }
-
-  IValue attr(const std::string& name, IValue or_else) const {
-    if (auto r = module_object()->type()->findAttributeSlot(name)) {
-      return module_object()->getSlot(*r);
-    }
-    return or_else;
-  }
-
-  bool hasattr(const std::string& name) const {
-    return module_object()->type()->findAttributeSlot(name).has_value();
-  }
-
-  // each module owns its method. The reference returned here
-  // is guarenteed to stay valid until this module has been destroyed
-  Method get_method(const std::string& name) const {
-    if (auto method = find_method(name)) {
-      return *method;
-    }
-    AT_ERROR("Method '", name, "' is not defined.");
+    _ivalue()->setAttr(name, module._ivalue());
   }

   void apply(const std::function<void(Module&)>& fn);

@@ -265,16 +170,6 @@ struct TORCH_API Module {
       bool print_param_values,
       int level) const;

-  const std::vector<Method> get_methods() const {
-    return fmap(
-        type()->methods(),
-        [&](Function* func) {
-          return Method(module_object(), func);
-        });
-  }
-
-  c10::optional<Method> find_method(const std::string& basename) const;
-
   /// Enables "training" mode.
   void train(bool on = true);
   /// Calls train(false) to enable "eval" mode.

@@ -311,24 +206,6 @@ struct TORCH_API Module {
   /// effect.
   void to(at::Device device, bool non_blocking = false);

-  /// Run a method from this module.
-  ///
-  /// For example:
-  /// @code
-  ///   IValue output = module->run("relu_script", a, b);
-  /// @endcode
-  ///
-  /// To get a compile a module from a source string, see torch::jit::compile
-  ///
-  /// @param method_name The name of the method to run
-  /// @param args Arguments to be passed to the method
-  /// @return An IValue containing the return value (or values if it is a tuple)
-  /// from the method
-  template <typename... Types>
-  IValue run_method(const std::string& method_name, Types&&... args) {
-    return get_method(method_name)({IValue(std::forward<Types>(args))...});
-  }
-
   void save(
       std::ostream& out,
       const ExtraFilesMap& extra_files = ExtraFilesMap()) const;

@@ -350,18 +227,6 @@ struct TORCH_API Module {
   void clone_method(const Module& orig, const std::string& name);

-  ModulePtr module_object() const;
-
-  ClassTypePtr type() const {
-    return module_object()->type();
-  }
-  std::shared_ptr<CompilationUnit> class_compilation_unit() const {
-    return module_object()->compilation_unit();
-  }
-
-  // so that C++ users can easily add methods
-  void define(const std::string& src, const ResolverPtr& resolver = nullptr);
-
   template <typename... Types>
   IValue create_class(const c10::QualifiedName& name, Types&&... args) const {
     return create_class(name, {IValue(std::forward<Types>(args))...});

@@ -369,10 +234,6 @@ struct TORCH_API Module {
   IValue create_class(const c10::QualifiedName& name, Stack stack) const;

-  size_t num_slots() const {
-    return module_object()->slots().size();
-  }
-
  private:
   Module clone_impl(std::unordered_map<TypePtr, TypePtr>& type_remap) const;

@@ -382,16 +243,13 @@ struct TORCH_API Module {
       const std::unordered_map<TypePtr, TypePtr>& type_remap);

   c10::QualifiedName getNameForMethod(std::string basename) const {
-    return QualifiedName(name(), basename);
+    return QualifiedName(*type()->name(), basename);
   }

   void to_impl(
       const c10::optional<at::Device>& device,
       const c10::optional<at::ScalarType>& dtype,
       bool non_blocking);
-
-  // mutable be we lazily initialize in module_object.
-  mutable ModulePtr module_value_;
 };

 namespace detail {

@@ -460,8 +318,8 @@ struct slot_iterator_impl {
     return cursors_.back();
   }
   IValue cur() const {
-    return return_module() ? top().module_.module_object()
-                           : top().module_.module_object()->getSlot(top().i_);
+    return return_module() ? top().module_._ivalue()
+                           : top().module_._ivalue()->getSlot(top().i_);
   }

   // advance to the next slot in a depth first pre-order traversal of the

@@ -487,11 +345,7 @@ struct slot_iterator_impl {
     // if the current thing is a module, we have to scan it for recursive
     // traversals. We do this by adding a new SlotCursor to track the traversal.
     if (recurse_ &&
-        top()
-            .module_.module_object()
-            ->type()
-            ->getAttribute(top().i_)
-            ->is_module()) {
+        top().module_._ivalue()->type()->getAttribute(top().i_)->is_module()) {
       cursors_.emplace_back(SlotCursor{cur().toModule(), 0});
       return;
     }

@@ -502,7 +356,7 @@ struct slot_iterator_impl {
   // otherwise, we have to continue advancing.
   bool valid() const {
     return top().i_ < int64_t(top().module_.num_slots()) &&
-        Policy::valid(top().module_.module_object()->type(), top().i_);
+        Policy::valid(top().module_._ivalue()->type(), top().i_);
   }
   void while_not_valid_next() {
     // advance iteration until we are either at the end (cursors_.empty())


@@ -0,0 +1,44 @@
#include <torch/csrc/jit/script/object.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/script/compilation_unit.h>
#include <torch/csrc/jit/script/resolver.h>
#include <torch/csrc/jit/script/sugared_value.h>
namespace torch {
namespace jit {
namespace script {
Object::Object(
std::shared_ptr<CompilationUnit> cu,
const c10::ClassTypePtr& type)
: Object(c10::ivalue::Object::create(
c10::StrongTypePtr(std::move(cu), type),
type->numAttributes())) {}
ObjectPtr Object::_ivalue() const {
TORCH_INTERNAL_ASSERT(_ivalue_);
return _ivalue_;
}
c10::optional<Method> Object::find_method(const std::string& basename) const {
for (Function* fn : type()->methods()) {
if (fn->name() == basename) {
return Method(_ivalue(), fn);
}
}
return c10::nullopt;
}
void Object::define(const std::string& src, const ResolverPtr& resolver) {
const auto self = SimpleSelf(type());
_ivalue()->compilation_unit()->define(
*type()->name(),
src,
resolver ? resolver : script::nativeResolver(),
&self);
}
} // namespace script
} // namespace jit
} // namespace torch


@@ -0,0 +1,110 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/script/method.h>
namespace torch {
namespace jit {
namespace script {
struct Resolver;
using ResolverPtr = std::shared_ptr<Resolver>;
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
struct TORCH_API Object {
Object() {}
Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {}
Object(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
Object(
c10::QualifiedName,
std::shared_ptr<CompilationUnit> cu,
bool shouldMangle = false);
ObjectPtr _ivalue() const;
c10::ClassTypePtr type() const {
return _ivalue()->type();
}
void setattr(const std::string& name, c10::IValue v) {
size_t slot = _ivalue()->type()->getAttributeSlot(name);
const c10::TypePtr& expected = _ivalue()->type()->getAttribute(slot);
TORCH_CHECK(expected, "Module has no attribute '", name, "'");
TORCH_CHECK(
v.type()->isSubtypeOf(expected),
"Expected a value of type '",
expected->python_str(),
"' for field '",
name,
"', but found '",
v.type()->python_str(),
"'");
_ivalue()->setSlot(slot, std::move(v));
}
c10::IValue attr(const std::string& name) const {
return _ivalue()->getAttr(name);
}
c10::IValue attr(const std::string& name, c10::IValue or_else) const {
if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
return _ivalue()->getSlot(*r);
}
return or_else;
}
bool hasattr(const std::string& name) const {
return _ivalue()->type()->findAttributeSlot(name).has_value();
}
// each object owns its methods. The reference returned here
// is guarenteed to stay valid until this module has been destroyed
Method get_method(const std::string& name) const {
if (auto method = find_method(name)) {
return *method;
}
AT_ERROR("Method '", name, "' is not defined.");
}
const std::vector<Method> get_methods() const {
return fmap(type()->methods(), [&](Function* func) {
return Method(_ivalue(), func);
});
}
c10::optional<Method> find_method(const std::string& basename) const;
/// Run a method from this module.
///
/// For example:
/// @code
/// IValue output = module->run("relu_script", a, b);
/// @endcode
///
/// To get a compile a module from a source string, see torch::jit::compile
///
/// @param method_name The name of the method to run
/// @param args Arguments to be passed to the method
/// @return An IValue containing the return value (or values if it is a tuple)
/// from the method
template <typename... Types>
IValue run_method(const std::string& method_name, Types&&... args) {
return get_method(method_name)({IValue(std::forward<Types>(args))...});
}
// so that C++ users can easily add methods
void define(const std::string& src, const ResolverPtr& resolver = nullptr);
size_t num_slots() const {
return _ivalue()->slots().size();
}
private:
// mutable be we lazily initialize in module_object.
mutable ObjectPtr _ivalue_;
};
} // namespace script
} // namespace jit
} // namespace torch
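
A small sketch of the attribute surface declared above (illustrative only; the attribute names "weight" and "bias" are assumptions, not from the PR). Attributes are slot-based, setattr() type-checks the incoming IValue against the slot's declared type, and attr() has an overload that takes a fallback value:

#include <torch/csrc/jit/script/module.h>
#include <torch/torch.h>

void poke_attributes(torch::jit::script::Module& m) {
  if (m.hasattr("weight")) {
    c10::IValue w = m.attr("weight");                      // read a slot
    m.setattr("weight", torch::zeros_like(w.toTensor()));  // type-checked write
  }
  c10::IValue maybe_bias = m.attr("bias", c10::IValue());  // fallback if absent
  (void)maybe_bias;
}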


@@ -290,7 +290,7 @@ static void gatherParametersAndBuffers(
     const std::string& prefix) {
   Graph& g = *self_value->owningGraph();

-  state->setValue(self.module_object(), self_value);
+  state->setValue(self._ivalue(), self_value);

   auto self_ty = self.type();
   for (const script::NameValue& s : self.named_attributes(/*recurse=*/false)) {

@@ -328,8 +328,8 @@ std::pair<std::shared_ptr<TracingState>, Stack> trace(
   // if we are a module, then make sure the modules parameters are in the map
   // and mapped to accesses to the self object
   if (self) {
-    Value* self_value =
-        state->graph->insertInput(0, "self")->setType(self->module_object()->type());
+    Value* self_value = state->graph->insertInput(0, "self")->setType(
+        self->_ivalue()->type());
     gatherParametersAndBuffers(state, self_value, *self, {"__module"});
   }


@@ -75,7 +75,7 @@ ScriptModuleOutput ScriptModuleBenchmark::runOnce(
       function.getSchema(),
       std::move(args),
       std::move(kwargs),
-      model_.module_object());
+      model_._ivalue());
   return function(std::move(stack));
 }

@@ -100,7 +100,7 @@ void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
       model_.get_method("forward").function().getSchema(),
       std::move(args),
       std::move(kwargs),
-      model_.module_object());
+      model_._ivalue());
   inputs_.emplace_back(std::move(stack));
 }


@@ -59,7 +59,7 @@ namespace detail {
 template <class Input, class Output, class Model>
 class BenchmarkHelper {
  public:
-  BenchmarkHelper(): initialized_{false} {}
+  BenchmarkHelper();
   explicit BenchmarkHelper(Model model): model_(model), initialized_(true) {}

   // This method to be used in benchmark() method

@@ -108,7 +108,14 @@ typedef BenchmarkHelper<
     at::IValue,
     jit::script::Module>
     ScriptModuleBenchmark;
+template <>
+inline BenchmarkHelper<ScriptModuleInput, at::IValue, jit::script::Module>::BenchmarkHelper()
+    : model_("Module", std::make_shared<jit::script::CompilationUnit>()),
+      initialized_(false) {}
+
 typedef BenchmarkHelper<ModuleInput, py::object, py::object> ModuleBenchmark;
+template <>
+inline BenchmarkHelper<ModuleInput, py::object, py::object>::BenchmarkHelper()
+    : initialized_(false) {}

 template <>
 void ScriptModuleBenchmark::runOnce(
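
The out-of-line constructors added above follow the usual pattern of declaring a member in the primary template and defining it separately per specialization, so the ScriptModule variant can hand its model_ member an explicit class name and CompilationUnit now that a usable Module no longer appears through lazy initialization. A generic sketch of that pattern (the types here are placeholders, not the PR's):

#include <string>
#include <utility>

struct Model {
  explicit Model(std::string name) : name_(std::move(name)) {}
  std::string name_;
};

template <class M>
class Helper {
 public:
  Helper();  // declared here, defined per specialization below
  explicit Helper(M model) : model_(std::move(model)), initialized_(true) {}

 private:
  M model_;
  bool initialized_ = false;
};

template <>
inline Helper<Model>::Helper() : model_("placeholder") {}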


@@ -1627,9 +1627,9 @@ if _enabled:
         @property
         def original_name(self):
-            if type(self) == self._c.name:
+            if type(self) == str(self._c._type().name()):
                 return ''
-            return self._c.name
+            return str(self._c._type().name())

         def define(self, src):
             # We use frames_up=1 to get to the proper surrounding scope. The stack