mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Factor Module into Object and Module
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29500 Test Plan: Imported from OSS Differential Revision: D18463064 Pulled By: jamesr66a fbshipit-source-id: d37bef242a8626593d4b8754042152cfc0f0acb2
This commit is contained in:
parent
14946a8891
commit
18bdf97dbb
|
|
@ -44,7 +44,7 @@ void dump_opnames(const script::Module& m, std::unordered_set<std::string>& opna
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (const auto& sub_m : m.children()) {
|
for (const auto& sub_m : m.children()) {
|
||||||
std::cout << "sub module name: " << sub_m.name().qualifiedName() << std::endl;
|
std::cout << "sub module name: " << sub_m.type()->name()->qualifiedName() << std::endl;
|
||||||
dump_opnames(sub_m, opnames);
|
dump_opnames(sub_m, opnames);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -444,6 +444,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||||
${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
|
${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
|
${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
|
${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
|
||||||
|
${TORCH_SRC_DIR}/csrc/jit/script/object.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
|
${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/source_range_serialization.cpp
|
${TORCH_SRC_DIR}/csrc/jit/source_range_serialization.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
|
${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
|
||||||
|
|
|
||||||
|
|
@ -89,12 +89,12 @@ void testScriptObject() {
|
||||||
Module m2("m2");
|
Module m2("m2");
|
||||||
std::vector<at::Tensor> constantTable;
|
std::vector<at::Tensor> constantTable;
|
||||||
import_libs(
|
import_libs(
|
||||||
m1.class_compilation_unit(),
|
m1._ivalue()->compilation_unit(),
|
||||||
"__torch__.FooTest",
|
"__torch__.FooTest",
|
||||||
std::make_shared<Source>(classSrcs1),
|
std::make_shared<Source>(classSrcs1),
|
||||||
constantTable);
|
constantTable);
|
||||||
import_libs(
|
import_libs(
|
||||||
m2.class_compilation_unit(),
|
m2._ivalue()->compilation_unit(),
|
||||||
"__torch__.FooTest",
|
"__torch__.FooTest",
|
||||||
std::make_shared<Source>(classSrcs2),
|
std::make_shared<Source>(classSrcs2),
|
||||||
constantTable);
|
constantTable);
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ void testModuleInterfaceSerialization() {
|
||||||
parentMod.register_attribute(
|
parentMod.register_attribute(
|
||||||
"subMod",
|
"subMod",
|
||||||
cu->get_interface("__torch__.OneForward"),
|
cu->get_interface("__torch__.OneForward"),
|
||||||
subMod.module_object(),
|
subMod._ivalue(),
|
||||||
/*is_parameter=*/false);
|
/*is_parameter=*/false);
|
||||||
parentMod.define(parentForward, nativeResolver());
|
parentMod.define(parentForward, nativeResolver());
|
||||||
ASSERT_TRUE(parentMod.hasattr("subMod"));
|
ASSERT_TRUE(parentMod.hasattr("subMod"));
|
||||||
|
|
|
||||||
|
|
@ -164,6 +164,7 @@ libtorch_sources = [
|
||||||
"torch/csrc/jit/hooks_for_testing.cpp",
|
"torch/csrc/jit/hooks_for_testing.cpp",
|
||||||
"torch/csrc/jit/script/builtin_functions.cpp",
|
"torch/csrc/jit/script/builtin_functions.cpp",
|
||||||
"torch/csrc/jit/script/module.cpp",
|
"torch/csrc/jit/script/module.cpp",
|
||||||
|
"torch/csrc/jit/script/object.cpp",
|
||||||
"torch/csrc/jit/tracer.cpp",
|
"torch/csrc/jit/tracer.cpp",
|
||||||
"torch/csrc/jit/fuser/kernel_cache.cpp",
|
"torch/csrc/jit/fuser/kernel_cache.cpp",
|
||||||
"torch/csrc/jit/fuser/compiler.cpp",
|
"torch/csrc/jit/fuser/compiler.cpp",
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ namespace serialize {
|
||||||
class TORCH_API OutputArchive final {
|
class TORCH_API OutputArchive final {
|
||||||
public:
|
public:
|
||||||
explicit OutputArchive(std::shared_ptr<jit::script::CompilationUnit> cu);
|
explicit OutputArchive(std::shared_ptr<jit::script::CompilationUnit> cu);
|
||||||
explicit OutputArchive() : cu_(std::make_shared<jit::script::CompilationUnit>()) {}
|
explicit OutputArchive() : cu_(std::make_shared<jit::script::CompilationUnit>()), module_("__torch__.Module", cu_) {}
|
||||||
|
|
||||||
// Move is allowed.
|
// Move is allowed.
|
||||||
OutputArchive(OutputArchive&&) = default;
|
OutputArchive(OutputArchive&&) = default;
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace serialize {
|
namespace serialize {
|
||||||
|
|
||||||
InputArchive::InputArchive() {}
|
InputArchive::InputArchive() : module_("Module", std::make_shared<jit::script::CompilationUnit>()) {}
|
||||||
|
|
||||||
void InputArchive::read(const std::string& key, c10::IValue& ivalue) {
|
void InputArchive::read(const std::string& key, c10::IValue& ivalue) {
|
||||||
ivalue = module_.attr(key);
|
ivalue = module_.attr(key);
|
||||||
|
|
|
||||||
|
|
@ -546,7 +546,7 @@ class ScriptModuleSerializer {
|
||||||
C10_LOG_API_USAGE_ONCE("torch.script.save");
|
C10_LOG_API_USAGE_ONCE("torch.script.save");
|
||||||
writeExtraFiles(module, extra_files);
|
writeExtraFiles(module, extra_files);
|
||||||
// Serialize the model object
|
// Serialize the model object
|
||||||
writeArchive("data", module.module_object());
|
writeArchive("data", module._ivalue());
|
||||||
// Then we werialize all code info.
|
// Then we werialize all code info.
|
||||||
writeCode(module.type());
|
writeCode(module.type());
|
||||||
// The tensor constants from the code are written to a separate archive
|
// The tensor constants from the code are written to a separate archive
|
||||||
|
|
|
||||||
|
|
@ -257,9 +257,9 @@ void ScriptModuleDeserializer::LEGACY_moduleSetState(
|
||||||
// TODO: once modules are first class in the interpreter and methods are not
|
// TODO: once modules are first class in the interpreter and methods are not
|
||||||
// lowered, change this to `module->run_method("__setstate__", {state});`
|
// lowered, change this to `module->run_method("__setstate__", {state});`
|
||||||
if (setstate->num_inputs() == 1) {
|
if (setstate->num_inputs() == 1) {
|
||||||
setstate->run({module.module_object()});
|
setstate->run({module._ivalue()});
|
||||||
} else if (setstate->num_inputs() == 2) {
|
} else if (setstate->num_inputs() == 2) {
|
||||||
setstate->run({module.module_object(), state});
|
setstate->run({module._ivalue(), state});
|
||||||
} else {
|
} else {
|
||||||
AT_ERROR("Unexpected schema on '__setstate__'");
|
AT_ERROR("Unexpected schema on '__setstate__'");
|
||||||
}
|
}
|
||||||
|
|
@ -348,11 +348,11 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule(
|
||||||
LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id()));
|
LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id()));
|
||||||
}
|
}
|
||||||
|
|
||||||
const ClassTypePtr& module_type = module.module_object()->type();
|
const ClassTypePtr& module_type = module._ivalue()->type();
|
||||||
for (size_t i = 0, N = module_type->numAttributes(); i < N; ++i) {
|
for (size_t i = 0, N = module_type->numAttributes(); i < N; ++i) {
|
||||||
// Verify that all the non-optional attributes have been initialized
|
// Verify that all the non-optional attributes have been initialized
|
||||||
// TODO: Issue #20497
|
// TODO: Issue #20497
|
||||||
const IValue& v = module.module_object()->getSlot(i);
|
const IValue& v = module._ivalue()->getSlot(i);
|
||||||
if (module_type->getAttribute(i)->kind() != TypeKind::OptionalType) {
|
if (module_type->getAttribute(i)->kind() != TypeKind::OptionalType) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
!v.isNone(),
|
!v.isNone(),
|
||||||
|
|
|
||||||
|
|
@ -193,7 +193,7 @@ struct SourceImporterImpl : public Resolver,
|
||||||
const script::Module& mod,
|
const script::Module& mod,
|
||||||
const std::shared_ptr<Source>& src) {
|
const std::shared_ptr<Source>& src) {
|
||||||
auto self = SimpleSelf(mod.type());
|
auto self = SimpleSelf(mod.type());
|
||||||
c10::QualifiedName prefix = mod.name();
|
c10::QualifiedName prefix = *mod.type()->name();
|
||||||
Parser p(src);
|
Parser p(src);
|
||||||
|
|
||||||
parsePossibleVersionNumber(p.lexer());
|
parsePossibleVersionNumber(p.lexer());
|
||||||
|
|
|
||||||
|
|
@ -267,7 +267,7 @@ void initJITBindings(PyObject* module) {
|
||||||
.def(
|
.def(
|
||||||
"_jit_pass_lower_graph",
|
"_jit_pass_lower_graph",
|
||||||
[](std::shared_ptr<Graph>& graph, const script::Module& self) {
|
[](std::shared_ptr<Graph>& graph, const script::Module& self) {
|
||||||
return LowerGraph(*graph, self.module_object());
|
return LowerGraph(*graph, self._ivalue());
|
||||||
})
|
})
|
||||||
.def("_jit_pass_loop_unrolling", UnrollLoops)
|
.def("_jit_pass_loop_unrolling", UnrollLoops)
|
||||||
.def(
|
.def(
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ void fillQConfigMap(
|
||||||
} else {
|
} else {
|
||||||
qconfig = parent_qconfig;
|
qconfig = parent_qconfig;
|
||||||
}
|
}
|
||||||
map[module.module_object()] = qconfig;
|
map[module._ivalue()] = qconfig;
|
||||||
|
|
||||||
for (const script::NameModule& s : module.named_children()) {
|
for (const script::NameModule& s : module.named_children()) {
|
||||||
std::string child_key;
|
std::string child_key;
|
||||||
|
|
@ -39,8 +39,7 @@ void fillQConfigMap(
|
||||||
} else {
|
} else {
|
||||||
child_key = key + "." + s.name;
|
child_key = key + "." + s.name;
|
||||||
}
|
}
|
||||||
fillQConfigMap(
|
fillQConfigMap(s.value._ivalue(), qconfig_dict, map, child_key, qconfig);
|
||||||
s.value.module_object(), qconfig_dict, map, child_key, qconfig);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -184,7 +183,8 @@ Node* InsertObserversHelper::insertObserverFor(
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
script::Module observer_module;
|
script::Module observer_module(
|
||||||
|
"Module", std::make_shared<script::CompilationUnit>());
|
||||||
if (isWeightOfConvOrLinear(v)) {
|
if (isWeightOfConvOrLinear(v)) {
|
||||||
TORCH_CHECK(v->uses().size() == 1, "We only support weight being used by one node.");
|
TORCH_CHECK(v->uses().size() == 1, "We only support weight being used by one node.");
|
||||||
observer_module = std::get<1>(qconfig);
|
observer_module = std::get<1>(qconfig);
|
||||||
|
|
@ -275,7 +275,7 @@ graph(%input, %weight, %bias, %4):
|
||||||
void InsertObserversHelper::insertObservers(
|
void InsertObserversHelper::insertObservers(
|
||||||
script::Module& module,
|
script::Module& module,
|
||||||
const std::string& method_name) {
|
const std::string& method_name) {
|
||||||
if (!module_qconfig_map_.count(module.module_object())) {
|
if (!module_qconfig_map_.count(module._ivalue())) {
|
||||||
// the module is added by us, e.g.: observer module
|
// the module is added by us, e.g.: observer module
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
@ -304,7 +304,7 @@ void InsertObserversHelper::insertObservers(
|
||||||
for (size_t idx = 1; idx < method.num_inputs(); ++idx) {
|
for (size_t idx = 1; idx < method.num_inputs(); ++idx) {
|
||||||
auto& v = graph->inputs()[idx];
|
auto& v = graph->inputs()[idx];
|
||||||
if (!values_to_skip_.count(v) && valueNeedsToBeQuantized(v)) {
|
if (!values_to_skip_.count(v) && valueNeedsToBeQuantized(v)) {
|
||||||
auto qconfig = module_qconfig_map_.at(module.module_object());
|
auto qconfig = module_qconfig_map_.at(module._ivalue());
|
||||||
if (qconfig) {
|
if (qconfig) {
|
||||||
auto observer_node =
|
auto observer_node =
|
||||||
insertObserverFor(v, v->owningGraph(), module, qconfig.value());
|
insertObserverFor(v, v->owningGraph(), module, qconfig.value());
|
||||||
|
|
@ -339,7 +339,8 @@ void InsertObserversHelper::insertObservers(
|
||||||
// the child module.
|
// the child module.
|
||||||
auto module_instance = n->inputs()[0];
|
auto module_instance = n->inputs()[0];
|
||||||
auto module_method_name = n->s(attr::name);
|
auto module_method_name = n->s(attr::name);
|
||||||
script::Module callee_module;
|
script::Module callee_module(
|
||||||
|
"Module", std::make_shared<script::CompilationUnit>());
|
||||||
if (module_instance->node()->kind() == prim::GetAttr) {
|
if (module_instance->node()->kind() == prim::GetAttr) {
|
||||||
auto child_module_name = module_instance->node()->s(attr::name);
|
auto child_module_name = module_instance->node()->s(attr::name);
|
||||||
callee_module = module.attr(child_module_name).toModule();
|
callee_module = module.attr(child_module_name).toModule();
|
||||||
|
|
@ -365,7 +366,7 @@ void InsertObserversHelper::insertObservers(
|
||||||
|
|
||||||
// Actually add observer nodes.
|
// Actually add observer nodes.
|
||||||
for (Value* v : values_to_observe) {
|
for (Value* v : values_to_observe) {
|
||||||
auto qconfig = module_qconfig_map_.at(module.module_object());
|
auto qconfig = module_qconfig_map_.at(module._ivalue());
|
||||||
// Skip inserting observer if no qconfig is specified
|
// Skip inserting observer if no qconfig is specified
|
||||||
if (qconfig) {
|
if (qconfig) {
|
||||||
insertObserverFor(v, v->owningGraph(), module, qconfig.value());
|
insertObserverFor(v, v->owningGraph(), module, qconfig.value());
|
||||||
|
|
@ -455,7 +456,7 @@ class QuantizeHelper {
|
||||||
// O(N) where N is number of observer moduels with this optimization
|
// O(N) where N is number of observer moduels with this optimization
|
||||||
for (int64_t i = observer_modules_to_remove_.size() - 1; i >= 0; --i) {
|
for (int64_t i = observer_modules_to_remove_.size() - 1; i >= 0; --i) {
|
||||||
auto observer_name = observer_modules_to_remove_[i];
|
auto observer_name = observer_modules_to_remove_[i];
|
||||||
module_.module_object()->unsafeRemoveAttr(observer_name);
|
module_._ivalue()->unsafeRemoveAttr(observer_name);
|
||||||
module_.type()->unsafeRemoveAttribute(observer_name);
|
module_.type()->unsafeRemoveAttribute(observer_name);
|
||||||
}
|
}
|
||||||
// Destroy observer forward calls
|
// Destroy observer forward calls
|
||||||
|
|
@ -826,7 +827,8 @@ graph(%self, %x):
|
||||||
|
|
||||||
script::Method method = current.get_method("forward");
|
script::Method method = current.get_method("forward");
|
||||||
GRAPH_DUMP(
|
GRAPH_DUMP(
|
||||||
current.name().name() + "::forward() before Conv2d-BatchNorm2d folding",
|
current.type()->name()->name() +
|
||||||
|
"::forward() before Conv2d-BatchNorm2d folding",
|
||||||
method.graph());
|
method.graph());
|
||||||
const auto& matches = findPatternMatches(pattern_graph, *method.graph());
|
const auto& matches = findPatternMatches(pattern_graph, *method.graph());
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -458,7 +458,7 @@ inline IValue toIValue(
|
||||||
auto classType = type->expect<ClassType>();
|
auto classType = type->expect<ClassType>();
|
||||||
if (auto mod = script::as_module(py::cast<py::object>(obj))) {
|
if (auto mod = script::as_module(py::cast<py::object>(obj))) {
|
||||||
// if obj is already a ScriptModule, just return its ivalue
|
// if obj is already a ScriptModule, just return its ivalue
|
||||||
return mod.value().module_object();
|
return mod.value()._ivalue();
|
||||||
}
|
}
|
||||||
// otherwise is a normal class object, we create a fresh
|
// otherwise is a normal class object, we create a fresh
|
||||||
// ivalue::Object to use from the py object.
|
// ivalue::Object to use from the py object.
|
||||||
|
|
@ -487,7 +487,7 @@ inline IValue toIValue(
|
||||||
IValue res;
|
IValue res;
|
||||||
if (auto mod = script::as_module(py::cast<py::object>(obj))) {
|
if (auto mod = script::as_module(py::cast<py::object>(obj))) {
|
||||||
classType = mod.value().type();
|
classType = mod.value().type();
|
||||||
res = mod.value().module_object();
|
res = mod.value()._ivalue();
|
||||||
} else {
|
} else {
|
||||||
// We inspect the value to found the compiled TorchScript class
|
// We inspect the value to found the compiled TorchScript class
|
||||||
// and then create a ivalue::Object from that class type.
|
// and then create a ivalue::Object from that class type.
|
||||||
|
|
@ -926,7 +926,7 @@ inline py::object invokeScriptMethodFromPython(
|
||||||
script::Method& callee,
|
script::Method& callee,
|
||||||
tuple_slice args,
|
tuple_slice args,
|
||||||
py::kwargs kwargs) {
|
py::kwargs kwargs) {
|
||||||
auto self = callee.owner().module_object();
|
auto self = callee.owner()._ivalue();
|
||||||
return runAndInsertCall(
|
return runAndInsertCall(
|
||||||
callee.function(),
|
callee.function(),
|
||||||
args,
|
args,
|
||||||
|
|
|
||||||
|
|
@ -725,7 +725,8 @@ void initPythonIRBindings(PyObject* module_) {
|
||||||
py::class_<ClassType, Type, std::shared_ptr<ClassType>>(m, "ClassType")
|
py::class_<ClassType, Type, std::shared_ptr<ClassType>>(m, "ClassType")
|
||||||
.def(py::init([](const std::string& qualified_name) {
|
.def(py::init([](const std::string& qualified_name) {
|
||||||
return get_python_cu()->get_class(c10::QualifiedName(qualified_name));
|
return get_python_cu()->get_class(c10::QualifiedName(qualified_name));
|
||||||
}));
|
}))
|
||||||
|
.def("name", [](ClassType& self) { return self.name()->name(); });
|
||||||
py::class_<InterfaceType, Type, std::shared_ptr<InterfaceType>>(
|
py::class_<InterfaceType, Type, std::shared_ptr<InterfaceType>>(
|
||||||
m, "InterfaceType")
|
m, "InterfaceType")
|
||||||
.def(py::init([](const std::string& qualified_name) {
|
.def(py::init([](const std::string& qualified_name) {
|
||||||
|
|
|
||||||
|
|
@ -390,9 +390,10 @@ void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
|
||||||
// Make a graph with a fake self argument
|
// Make a graph with a fake self argument
|
||||||
auto graph = func.function_->graph()->copy();
|
auto graph = func.function_->graph()->copy();
|
||||||
auto v = graph->insertInput(0, "self");
|
auto v = graph->insertInput(0, "self");
|
||||||
v->setType(module.module_object()->type());
|
v->setType(module._ivalue()->type());
|
||||||
const auto name = QualifiedName(module.name(), "forward");
|
const auto name = QualifiedName(*module.type()->name(), "forward");
|
||||||
auto method = module.class_compilation_unit()->create_function(name, graph);
|
auto method =
|
||||||
|
module._ivalue()->compilation_unit()->create_function(name, graph);
|
||||||
module.type()->addMethod(method);
|
module.type()->addMethod(method);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -403,7 +404,7 @@ bool ivalue_tags_match(const Module& lhs, const Module& rhs) {
|
||||||
IValue b;
|
IValue b;
|
||||||
};
|
};
|
||||||
std::unordered_set<const void*> visited;
|
std::unordered_set<const void*> visited;
|
||||||
std::vector<Work> work = {{lhs.module_object(), rhs.module_object()}};
|
std::vector<Work> work = {{lhs._ivalue(), rhs._ivalue()}};
|
||||||
while (!work.empty()) {
|
while (!work.empty()) {
|
||||||
Work item = work.back();
|
Work item = work.back();
|
||||||
work.pop_back();
|
work.pop_back();
|
||||||
|
|
@ -501,7 +502,7 @@ struct slot_dict_impl {
|
||||||
static void bind(const py::module& m, const char* name) {
|
static void bind(const py::module& m, const char* name) {
|
||||||
py::class_<slot_dict_impl<Policy>>(m, name)
|
py::class_<slot_dict_impl<Policy>>(m, name)
|
||||||
.def(py::init(
|
.def(py::init(
|
||||||
[](Module& m) { return slot_dict_impl<Policy>(m.module_object()); }))
|
[](Module& m) { return slot_dict_impl<Policy>(m._ivalue()); }))
|
||||||
.def("contains", &slot_dict_impl<Policy>::contains)
|
.def("contains", &slot_dict_impl<Policy>::contains)
|
||||||
.def("items", &slot_dict_impl<Policy>::items)
|
.def("items", &slot_dict_impl<Policy>::items)
|
||||||
.def("setattr", &slot_dict_impl<Policy>::setattr)
|
.def("setattr", &slot_dict_impl<Policy>::setattr)
|
||||||
|
|
@ -562,10 +563,56 @@ void initJitScriptBindings(PyObject* module) {
|
||||||
// follows.
|
// follows.
|
||||||
py::bind_map<ExtraFilesMap>(m, "ExtraFilesMap");
|
py::bind_map<ExtraFilesMap>(m, "ExtraFilesMap");
|
||||||
|
|
||||||
|
py::class_<Object>(m, "ScriptObject")
|
||||||
|
.def("_type", [](Module& m) { return m.type(); })
|
||||||
|
.def(
|
||||||
|
"_get_method",
|
||||||
|
[](Object& self, const std::string& name) -> Method {
|
||||||
|
return self.get_method(name);
|
||||||
|
},
|
||||||
|
py::keep_alive<0, 1>())
|
||||||
|
.def(
|
||||||
|
"setattr",
|
||||||
|
[](Object& self, const std::string& name, py::object value) {
|
||||||
|
TypePtr type = self.type()->getAttribute(name);
|
||||||
|
TORCH_CHECK(type, "Module has no attribute '", name, "'");
|
||||||
|
auto ivalue = toIValue(std::move(value), type);
|
||||||
|
self.setattr(name, ivalue);
|
||||||
|
})
|
||||||
|
.def(
|
||||||
|
"getattr",
|
||||||
|
[](Object& self, const std::string& name) {
|
||||||
|
return toPyObject(self.attr(name));
|
||||||
|
})
|
||||||
|
.def(
|
||||||
|
"__getattr__",
|
||||||
|
[](Object& self, const std::string& name) {
|
||||||
|
if (auto method = self.find_method(name)) {
|
||||||
|
return py::cast(*method);
|
||||||
|
}
|
||||||
|
return toPyObject(self.attr(name));
|
||||||
|
})
|
||||||
|
.def(
|
||||||
|
"hasattr",
|
||||||
|
[](Object& self, const std::string& name) {
|
||||||
|
return self.hasattr(name);
|
||||||
|
})
|
||||||
|
.def(
|
||||||
|
"_has_method",
|
||||||
|
[](Object& self, const std::string& name) {
|
||||||
|
return bool(self.find_method(name));
|
||||||
|
})
|
||||||
|
.def(
|
||||||
|
"_method_names", [](Object& self) {
|
||||||
|
return fmap(self.get_methods(), [](const Method& method) {
|
||||||
|
return method.name();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
// torch.jit.ScriptModule is a subclass of this C++ object.
|
// torch.jit.ScriptModule is a subclass of this C++ object.
|
||||||
// Methods here are prefixed with _ since they should not be
|
// Methods here are prefixed with _ since they should not be
|
||||||
// public.
|
// public.
|
||||||
py::class_<Module>(m, "ScriptModule")
|
py::class_<Module, Object>(m, "ScriptModule")
|
||||||
.def(py::init<std::string, std::shared_ptr<CompilationUnit>, bool>())
|
.def(py::init<std::string, std::shared_ptr<CompilationUnit>, bool>())
|
||||||
.def(
|
.def(
|
||||||
"save",
|
"save",
|
||||||
|
|
@ -598,46 +645,10 @@ void initJitScriptBindings(PyObject* module) {
|
||||||
py::arg("attrs") = true,
|
py::arg("attrs") = true,
|
||||||
py::arg("params") = true,
|
py::arg("params") = true,
|
||||||
py::arg("indent") = 0)
|
py::arg("indent") = 0)
|
||||||
.def(
|
|
||||||
"_define",
|
|
||||||
[](Module& m,
|
|
||||||
std::shared_ptr<ConcreteModuleType> concreteType,
|
|
||||||
const std::string& script,
|
|
||||||
ResolutionCallback rcb) {
|
|
||||||
const auto self = ModuleSelf(std::move(concreteType));
|
|
||||||
m.class_compilation_unit()->define(
|
|
||||||
m.name(), script, pythonResolver(rcb), &self);
|
|
||||||
didFinishEmitModule(m);
|
|
||||||
})
|
|
||||||
.def("_type", [](Module& m) { return m.type(); })
|
|
||||||
.def(
|
|
||||||
"_get_method",
|
|
||||||
[](Module& self, const std::string& name) -> Method {
|
|
||||||
return self.get_method(name);
|
|
||||||
},
|
|
||||||
py::keep_alive<0, 1>())
|
|
||||||
.def(
|
|
||||||
"setattr",
|
|
||||||
[](Module& self, const std::string& name, py::object value) {
|
|
||||||
TypePtr type = self.type()->getAttribute(name);
|
|
||||||
TORCH_CHECK(type, "Module has no attribute '", name, "'");
|
|
||||||
auto ivalue = toIValue(std::move(value), type);
|
|
||||||
self.setattr(name, ivalue);
|
|
||||||
})
|
|
||||||
.def(
|
|
||||||
"getattr",
|
|
||||||
[](Module& self, const std::string& name) {
|
|
||||||
return toPyObject(self.attr(name));
|
|
||||||
})
|
|
||||||
.def(
|
|
||||||
"hasattr",
|
|
||||||
[](Module& self, const std::string& name) {
|
|
||||||
return self.hasattr(name);
|
|
||||||
})
|
|
||||||
.def(
|
.def(
|
||||||
"_replicate_for_data_parallel",
|
"_replicate_for_data_parallel",
|
||||||
[](Module& module) {
|
[](Module& module) {
|
||||||
const ModulePtr& obj = module.module_object();
|
const ModulePtr& obj = module._ivalue();
|
||||||
auto copy = c10::ivalue::Object::create(
|
auto copy = c10::ivalue::Object::create(
|
||||||
c10::StrongTypePtr(obj->compilation_unit(), obj->type()),
|
c10::StrongTypePtr(obj->compilation_unit(), obj->type()),
|
||||||
obj->slots().size());
|
obj->slots().size());
|
||||||
|
|
@ -647,16 +658,24 @@ void initJitScriptBindings(PyObject* module) {
|
||||||
return Module(std::move(copy));
|
return Module(std::move(copy));
|
||||||
})
|
})
|
||||||
.def(
|
.def(
|
||||||
"_has_method",
|
"get_debug_state",
|
||||||
[](Module& self, const std::string& name) {
|
[](Module& self) {
|
||||||
return bool(self.find_method(name));
|
if (auto m = self.find_method("forward")) {
|
||||||
|
return m->get_executor().getDebugState();
|
||||||
|
}
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Attempted to call get_debug_state on a Module without a compiled forward()");
|
||||||
})
|
})
|
||||||
.def(
|
.def(
|
||||||
"_method_names",
|
"_define",
|
||||||
[](Module& self) {
|
[](Module& m,
|
||||||
return fmap(self.get_methods(), [](const Method& method) {
|
std::shared_ptr<ConcreteModuleType> concreteType,
|
||||||
return method.name();
|
const std::string& script,
|
||||||
});
|
ResolutionCallback rcb) {
|
||||||
|
const auto self = ModuleSelf(std::move(concreteType));
|
||||||
|
m._ivalue()->compilation_unit()->define(
|
||||||
|
*m.type()->name(), script, pythonResolver(rcb), &self);
|
||||||
|
didFinishEmitModule(m);
|
||||||
})
|
})
|
||||||
.def(
|
.def(
|
||||||
"_create_method_from_trace",
|
"_create_method_from_trace",
|
||||||
|
|
@ -672,21 +691,12 @@ void initJitScriptBindings(PyObject* module) {
|
||||||
|
|
||||||
std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing(
|
std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing(
|
||||||
func, typed_inputs, var_lookup_fn, force_outplace, &self));
|
func, typed_inputs, var_lookup_fn, force_outplace, &self));
|
||||||
const auto method_name = QualifiedName(self.name(), name);
|
const auto method_name = QualifiedName(*self.type()->name(), name);
|
||||||
auto fn = self.class_compilation_unit()->create_function(
|
auto fn = self._ivalue()->compilation_unit()->create_function(
|
||||||
method_name, graph);
|
method_name, graph);
|
||||||
self.type()->addMethod(fn);
|
self.type()->addMethod(fn);
|
||||||
didFinishEmitModule(self);
|
didFinishEmitModule(self);
|
||||||
})
|
})
|
||||||
.def(
|
|
||||||
"get_debug_state",
|
|
||||||
[](Module& self) {
|
|
||||||
if (auto m = self.find_method("forward")) {
|
|
||||||
return m->get_executor().getDebugState();
|
|
||||||
}
|
|
||||||
throw std::runtime_error(
|
|
||||||
"Attempted to call get_debug_state on a Module without a compiled forward()");
|
|
||||||
})
|
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"code",
|
"code",
|
||||||
[](Module& self) {
|
[](Module& self) {
|
||||||
|
|
@ -697,9 +707,7 @@ void initJitScriptBindings(PyObject* module) {
|
||||||
return pp.str();
|
return pp.str();
|
||||||
})
|
})
|
||||||
.def("apply", &Module::apply)
|
.def("apply", &Module::apply)
|
||||||
.def("_clone", &Module::clone)
|
.def("_clone", &Module::clone);
|
||||||
.def_property_readonly(
|
|
||||||
"name", [](const Module& self) { return self.name().name(); });
|
|
||||||
|
|
||||||
slot_dict_impl<script::detail::ParameterPolicy>::bind(m, "ParameterDict");
|
slot_dict_impl<script::detail::ParameterPolicy>::bind(m, "ParameterDict");
|
||||||
slot_dict_impl<script::detail::BufferPolicy>::bind(m, "BufferDict");
|
slot_dict_impl<script::detail::BufferPolicy>::bind(m, "BufferDict");
|
||||||
|
|
|
||||||
65
torch/csrc/jit/script/method.h
Normal file
65
torch/csrc/jit/script/method.h
Normal file
|
|
@ -0,0 +1,65 @@
|
||||||
|
|
||||||
|
#include <ATen/core/ivalue.h>
|
||||||
|
#include <ATen/core/stack.h>
|
||||||
|
#include <torch/csrc/jit/function.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
namespace script {
|
||||||
|
|
||||||
|
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
|
||||||
|
|
||||||
|
// A method in a module, e.g. f in:
|
||||||
|
//
|
||||||
|
// class M(ScriptModule):
|
||||||
|
// @script_method
|
||||||
|
// def f(self, x):
|
||||||
|
// ...
|
||||||
|
// Note: because Method/Module are exposed to python these
|
||||||
|
// classes use python method naming conventions
|
||||||
|
struct TORCH_API Method {
|
||||||
|
Method(ObjectPtr owner, Function* function);
|
||||||
|
|
||||||
|
// the module that contains this method.
|
||||||
|
Module owner() const;
|
||||||
|
void run(Stack& stack);
|
||||||
|
void run(Stack&& stack) {
|
||||||
|
run(stack);
|
||||||
|
}
|
||||||
|
|
||||||
|
c10::IValue operator()(
|
||||||
|
std::vector<c10::IValue> stack,
|
||||||
|
const Kwargs& kwargs = Kwargs());
|
||||||
|
|
||||||
|
std::shared_ptr<Graph> graph() const {
|
||||||
|
return function_->graph();
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string& name() const {
|
||||||
|
return function_->name();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t num_inputs() const {
|
||||||
|
return function_->num_inputs();
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphExecutor& get_executor() {
|
||||||
|
return function_->get_executor();
|
||||||
|
}
|
||||||
|
|
||||||
|
Function& function() const {
|
||||||
|
return *function_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Methods are uniqued onwed by a single module. This raw pointer allows
|
||||||
|
// looking up the module.
|
||||||
|
ObjectPtr owner_;
|
||||||
|
|
||||||
|
// Underlying unbound function
|
||||||
|
Function* function_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace script
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
|
|
@ -14,7 +14,7 @@ namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace script {
|
namespace script {
|
||||||
|
|
||||||
static ModulePtr create_module_object(
|
static ObjectPtr create_module_object(
|
||||||
c10::QualifiedName class_name,
|
c10::QualifiedName class_name,
|
||||||
std::shared_ptr<CompilationUnit> cu,
|
std::shared_ptr<CompilationUnit> cu,
|
||||||
bool shouldMangle = false) {
|
bool shouldMangle = false) {
|
||||||
|
|
@ -33,14 +33,14 @@ static ModulePtr create_module_object(
|
||||||
}
|
}
|
||||||
|
|
||||||
Module::Module(c10::QualifiedName class_name)
|
Module::Module(c10::QualifiedName class_name)
|
||||||
: module_value_(create_module_object(
|
: Object(create_module_object(
|
||||||
std::move(class_name),
|
std::move(class_name),
|
||||||
std::make_shared<CompilationUnit>())) {}
|
std::make_shared<CompilationUnit>())) {}
|
||||||
|
|
||||||
Module::Module(
|
Module::Module(
|
||||||
std::shared_ptr<CompilationUnit> cu,
|
std::shared_ptr<CompilationUnit> cu,
|
||||||
const c10::ClassTypePtr& type)
|
const c10::ClassTypePtr& type)
|
||||||
: module_value_(c10::ivalue::Object::create(
|
: Object(c10::ivalue::Object::create(
|
||||||
c10::StrongTypePtr(std::move(cu), type),
|
c10::StrongTypePtr(std::move(cu), type),
|
||||||
type->numAttributes())) {}
|
type->numAttributes())) {}
|
||||||
|
|
||||||
|
|
@ -48,21 +48,11 @@ Module::Module(
|
||||||
c10::QualifiedName class_name,
|
c10::QualifiedName class_name,
|
||||||
std::shared_ptr<CompilationUnit> cu,
|
std::shared_ptr<CompilationUnit> cu,
|
||||||
bool shouldMangle)
|
bool shouldMangle)
|
||||||
: module_value_(create_module_object(
|
: Object(create_module_object(
|
||||||
std::move(class_name),
|
std::move(class_name),
|
||||||
std::move(cu),
|
std::move(cu),
|
||||||
shouldMangle)) {}
|
shouldMangle)) {}
|
||||||
|
|
||||||
ModulePtr Module::module_object() const {
|
|
||||||
if (!module_value_) {
|
|
||||||
// User has created a Model without assigning it to something already
|
|
||||||
// loaded. This is done in tests, and when using the .define method.
|
|
||||||
module_value_ =
|
|
||||||
create_module_object("Module", std::make_shared<CompilationUnit>());
|
|
||||||
}
|
|
||||||
return module_value_;
|
|
||||||
}
|
|
||||||
|
|
||||||
// first class mode runs models as first class objects,
|
// first class mode runs models as first class objects,
|
||||||
// and does not force inlining everywhere. This is experimental
|
// and does not force inlining everywhere. This is experimental
|
||||||
// as we bring up the system since it will degrade performance
|
// as we bring up the system since it will degrade performance
|
||||||
|
|
@ -155,21 +145,15 @@ Module Method::owner() const {
|
||||||
return Module(owner_);
|
return Module(owner_);
|
||||||
}
|
}
|
||||||
void Method::run(Stack& stack) {
|
void Method::run(Stack& stack) {
|
||||||
stack.insert(stack.begin(), owner().module_object());
|
stack.insert(stack.begin(), owner()._ivalue());
|
||||||
function_->run(stack);
|
function_->run(stack);
|
||||||
}
|
}
|
||||||
|
|
||||||
IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
|
IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
|
||||||
stack.insert(stack.begin(), owner().module_object());
|
stack.insert(stack.begin(), owner()._ivalue());
|
||||||
return (*function_)(std::move(stack), kwargs);
|
return (*function_)(std::move(stack), kwargs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Module::define(const std::string& src, const ResolverPtr& resolver) {
|
|
||||||
const auto self = SimpleSelf(type());
|
|
||||||
class_compilation_unit()->define(
|
|
||||||
name(), src, resolver ? resolver : script::nativeResolver(), &self);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Module::clone_method(
|
void Module::clone_method(
|
||||||
const Module& orig,
|
const Module& orig,
|
||||||
const Function& method,
|
const Function& method,
|
||||||
|
|
@ -195,7 +179,7 @@ void Module::clone_method(
|
||||||
auto schema = method.getSchema().cloneWithRemappedTypes(type_remap_fn);
|
auto schema = method.getSchema().cloneWithRemappedTypes(type_remap_fn);
|
||||||
const auto this_method_name = getNameForMethod(method.name());
|
const auto this_method_name = getNameForMethod(method.name());
|
||||||
auto copied =
|
auto copied =
|
||||||
class_compilation_unit()->create_function(this_method_name, graph);
|
_ivalue()->compilation_unit()->create_function(this_method_name, graph);
|
||||||
type()->addMethod(copied);
|
type()->addMethod(copied);
|
||||||
copied->setSchema(std::move(schema));
|
copied->setSchema(std::move(schema));
|
||||||
}
|
}
|
||||||
|
|
@ -206,8 +190,7 @@ void Module::clone_method(const Module& orig, const std::string& name) {
|
||||||
while (!to_scan.empty()) {
|
while (!to_scan.empty()) {
|
||||||
auto entry = to_scan.back();
|
auto entry = to_scan.back();
|
||||||
to_scan.pop_back();
|
to_scan.pop_back();
|
||||||
type_remap[entry.first.module_object()->type()] =
|
type_remap[entry.first._ivalue()->type()] = entry.second._ivalue()->type();
|
||||||
entry.second.module_object()->type();
|
|
||||||
for (const NameModule& s : entry.first.named_children()) {
|
for (const NameModule& s : entry.first.named_children()) {
|
||||||
to_scan.emplace_back(
|
to_scan.emplace_back(
|
||||||
s.value, Module(entry.second.attr(s.name).toObject()));
|
s.value, Module(entry.second.attr(s.name).toObject()));
|
||||||
|
|
@ -223,16 +206,16 @@ Module Module::clone() const {
|
||||||
|
|
||||||
Module Module::clone_impl(
|
Module Module::clone_impl(
|
||||||
std::unordered_map<TypePtr, TypePtr>& type_remap) const {
|
std::unordered_map<TypePtr, TypePtr>& type_remap) const {
|
||||||
// Create a new module_object in the same compilation unit.
|
// Create a new _ivalue in the same compilation unit.
|
||||||
// The name is the same as for the original module, but it'll be mangled.
|
// The name is the same as for the original module, but it'll be mangled.
|
||||||
// The class type is also created from scratch.
|
// The class type is also created from scratch.
|
||||||
Module r(name(), class_compilation_unit(), true);
|
Module r(*type()->name(), _ivalue()->compilation_unit(), true);
|
||||||
type_remap[type()] = r.type();
|
type_remap[type()] = r.type();
|
||||||
|
|
||||||
// Copy slots. If a slot is a module - recursively clone it.
|
// Copy slots. If a slot is a module - recursively clone it.
|
||||||
size_t N = type()->numAttributes();
|
size_t N = type()->numAttributes();
|
||||||
for (size_t i = 0; i < N; ++i) {
|
for (size_t i = 0; i < N; ++i) {
|
||||||
IValue s = module_object()->getSlot(i);
|
IValue s = _ivalue()->getSlot(i);
|
||||||
if (type()->getAttribute(i)->is_module()) {
|
if (type()->getAttribute(i)->is_module()) {
|
||||||
const Module& orig = Module(s.toObject());
|
const Module& orig = Module(s.toObject());
|
||||||
Module cloned = orig.clone_impl(type_remap);
|
Module cloned = orig.clone_impl(type_remap);
|
||||||
|
|
@ -256,8 +239,8 @@ Module Module::clone_impl(
|
||||||
|
|
||||||
void Module::train(bool on) {
|
void Module::train(bool on) {
|
||||||
for (Module m : modules()) {
|
for (Module m : modules()) {
|
||||||
if (auto slot = m.module_object()->type()->findAttributeSlot("training")) {
|
if (auto slot = m._ivalue()->type()->findAttributeSlot("training")) {
|
||||||
m.module_object()->setSlot(*slot, on);
|
m._ivalue()->setSlot(*slot, on);
|
||||||
} else {
|
} else {
|
||||||
TORCH_INTERNAL_ASSERT("'training' attribute not found");
|
TORCH_INTERNAL_ASSERT("'training' attribute not found");
|
||||||
}
|
}
|
||||||
|
|
@ -267,7 +250,7 @@ void Module::train(bool on) {
|
||||||
IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
|
IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
|
||||||
// Look up the class
|
// Look up the class
|
||||||
const auto classType =
|
const auto classType =
|
||||||
class_compilation_unit()->get_class(c10::QualifiedName(name));
|
_ivalue()->compilation_unit()->get_class(c10::QualifiedName(name));
|
||||||
if (!classType) {
|
if (!classType) {
|
||||||
AT_ERROR(
|
AT_ERROR(
|
||||||
"Could not find class with name: '",
|
"Could not find class with name: '",
|
||||||
|
|
@ -278,7 +261,7 @@ IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
|
||||||
// Create a bare object with correct number of slots
|
// Create a bare object with correct number of slots
|
||||||
const size_t numAttrs = classType->numAttributes();
|
const size_t numAttrs = classType->numAttributes();
|
||||||
auto obj = c10::ivalue::Object::create(
|
auto obj = c10::ivalue::Object::create(
|
||||||
c10::StrongTypePtr(class_compilation_unit(), classType), numAttrs);
|
c10::StrongTypePtr(_ivalue()->compilation_unit(), classType), numAttrs);
|
||||||
|
|
||||||
// Invoke the `__init__()` of the class with the arguments provided.
|
// Invoke the `__init__()` of the class with the arguments provided.
|
||||||
Stack stackWithSelf = {obj};
|
Stack stackWithSelf = {obj};
|
||||||
|
|
@ -319,15 +302,6 @@ named_parameter_list Module::named_parameters(bool recurse) const {
|
||||||
return named_parameter_list(*this, recurse, /*return_module=*/false);
|
return named_parameter_list(*this, recurse, /*return_module=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
c10::optional<Method> Module::find_method(const std::string& basename) const {
|
|
||||||
for (Function* fn : type()->methods()) {
|
|
||||||
if (fn->name() == basename) {
|
|
||||||
return Method(module_object(), fn);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c10::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
attribute_list Module::attributes(bool recurse) const {
|
attribute_list Module::attributes(bool recurse) const {
|
||||||
return attribute_list(*this, recurse, /*return_module=*/false);
|
return attribute_list(*this, recurse, /*return_module=*/false);
|
||||||
}
|
}
|
||||||
|
|
@ -380,7 +354,7 @@ std::string Module::dump_to_str(
|
||||||
methods_ss << " }" << std::endl;
|
methods_ss << " }" << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
ss << "module " << name().qualifiedName() << " {" << std::endl;
|
ss << "module " << type()->name()->qualifiedName() << " {" << std::endl;
|
||||||
ss << " parameters {" << std::endl;
|
ss << " parameters {" << std::endl;
|
||||||
ss << torch::jit::jit_log_prefix(" ", parameters_ss.str());
|
ss << torch::jit::jit_log_prefix(" ", parameters_ss.str());
|
||||||
ss << " }" << std::endl;
|
ss << " }" << std::endl;
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
#include <torch/csrc/jit/ir.h>
|
#include <torch/csrc/jit/ir.h>
|
||||||
#include <torch/csrc/jit/named_value.h>
|
#include <torch/csrc/jit/named_value.h>
|
||||||
#include <torch/csrc/jit/passes/shape_analysis.h>
|
#include <torch/csrc/jit/passes/shape_analysis.h>
|
||||||
|
#include <torch/csrc/jit/script/object.h>
|
||||||
#include <torch/csrc/jit/source_range.h>
|
#include <torch/csrc/jit/source_range.h>
|
||||||
|
|
||||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||||
|
|
@ -84,73 +85,17 @@ using named_buffer_list =
|
||||||
|
|
||||||
using ModuleLookup = std::function<Module(const std::vector<std::string>&)>;
|
using ModuleLookup = std::function<Module(const std::vector<std::string>&)>;
|
||||||
|
|
||||||
// A method in a module, e.g. f in:
|
struct TORCH_API Module : public Object {
|
||||||
//
|
|
||||||
// class M(ScriptModule):
|
|
||||||
// @script_method
|
|
||||||
// def f(self, x):
|
|
||||||
// ...
|
|
||||||
// Note: because Method/Module are exposed to python these
|
|
||||||
// classes use python method naming conventions
|
|
||||||
struct TORCH_API Method {
|
|
||||||
Method(ModulePtr owner, Function* function);
|
|
||||||
|
|
||||||
// the module that contains this method.
|
|
||||||
Module owner() const;
|
|
||||||
void run(Stack& stack);
|
|
||||||
void run(Stack&& stack) {
|
|
||||||
run(stack);
|
|
||||||
}
|
|
||||||
|
|
||||||
IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs());
|
|
||||||
|
|
||||||
std::shared_ptr<Graph> graph() const {
|
|
||||||
return function_->graph();
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::string& name() const {
|
|
||||||
return function_->name();
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t num_inputs() const {
|
|
||||||
return function_->num_inputs();
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphExecutor& get_executor() {
|
|
||||||
return function_->get_executor();
|
|
||||||
}
|
|
||||||
|
|
||||||
Function& function() const {
|
|
||||||
return *function_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
// Methods are uniqued onwed by a single module. This raw pointer allows
|
|
||||||
// looking up the module.
|
|
||||||
ModulePtr owner_;
|
|
||||||
|
|
||||||
// Underlying unbound function
|
|
||||||
// This is the _lowered_ function and is different than the
|
|
||||||
// first-class function in class_compilation_unit()
|
|
||||||
Function* function_;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TORCH_API Module {
|
|
||||||
explicit Module(c10::QualifiedName class_name);
|
explicit Module(c10::QualifiedName class_name);
|
||||||
Module(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
|
Module(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
|
||||||
|
Module() {}
|
||||||
Module(
|
Module(
|
||||||
c10::QualifiedName,
|
c10::QualifiedName,
|
||||||
std::shared_ptr<CompilationUnit> cu,
|
std::shared_ptr<CompilationUnit> cu,
|
||||||
bool shouldMangle = false);
|
bool shouldMangle = false);
|
||||||
// module_value_ null and will be lazily initialized if is needed
|
Module(ModulePtr module_value) : Object(std::move(module_value)) {}
|
||||||
Module() {}
|
|
||||||
Module(ModulePtr module_value) : module_value_(std::move(module_value)) {}
|
|
||||||
~Module() {}
|
~Module() {}
|
||||||
|
|
||||||
const c10::QualifiedName& name() const {
|
|
||||||
return *module_object()->type()->name();
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_optimized(bool o) {
|
void set_optimized(bool o) {
|
||||||
AT_WARN(
|
AT_WARN(
|
||||||
"Module::set_optimized() is deprecated and has no effect. "
|
"Module::set_optimized() is deprecated and has no effect. "
|
||||||
|
|
@ -174,14 +119,14 @@ struct TORCH_API Module {
|
||||||
// whether a slot is a parameter to be able to classify it.
|
// whether a slot is a parameter to be able to classify it.
|
||||||
void register_buffer(const std::string& name, at::Tensor v) {
|
void register_buffer(const std::string& name, at::Tensor v) {
|
||||||
type()->addOrCheckAttribute(name, TensorType::get());
|
type()->addOrCheckAttribute(name, TensorType::get());
|
||||||
module_object()->setAttr(name, std::move(v));
|
_ivalue()->setAttr(name, std::move(v));
|
||||||
}
|
}
|
||||||
void register_parameter(
|
void register_parameter(
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
at::Tensor v,
|
at::Tensor v,
|
||||||
bool is_buffer) {
|
bool is_buffer) {
|
||||||
type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer);
|
type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer);
|
||||||
module_object()->setAttr(name, std::move(v));
|
_ivalue()->setAttr(name, std::move(v));
|
||||||
}
|
}
|
||||||
void register_attribute(
|
void register_attribute(
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
|
|
@ -189,51 +134,11 @@ struct TORCH_API Module {
|
||||||
IValue v,
|
IValue v,
|
||||||
bool is_param = false) {
|
bool is_param = false) {
|
||||||
type()->addOrCheckAttribute(name, t, is_param);
|
type()->addOrCheckAttribute(name, t, is_param);
|
||||||
module_object()->setAttr(name, std::move(v));
|
_ivalue()->setAttr(name, std::move(v));
|
||||||
}
|
}
|
||||||
void register_module(const std::string& name, const Module& module) {
|
void register_module(const std::string& name, const Module& module) {
|
||||||
type()->addOrCheckAttribute(name, module.type());
|
type()->addOrCheckAttribute(name, module.type());
|
||||||
module_object()->setAttr(name, module.module_object());
|
_ivalue()->setAttr(name, module._ivalue());
|
||||||
}
|
|
||||||
|
|
||||||
void setattr(const std::string& name, IValue v) {
|
|
||||||
size_t slot = module_object()->type()->getAttributeSlot(name);
|
|
||||||
const TypePtr& expected = module_object()->type()->getAttribute(slot);
|
|
||||||
TORCH_CHECK(expected, "Module has no attribute '", name, "'");
|
|
||||||
TORCH_CHECK(
|
|
||||||
v.type()->isSubtypeOf(expected),
|
|
||||||
"Expected a value of type '",
|
|
||||||
expected->python_str(),
|
|
||||||
"' for field '",
|
|
||||||
name,
|
|
||||||
"', but found '",
|
|
||||||
v.type()->python_str(),
|
|
||||||
"'");
|
|
||||||
module_object()->setSlot(slot, std::move(v));
|
|
||||||
}
|
|
||||||
|
|
||||||
IValue attr(const std::string& name) const {
|
|
||||||
return module_object()->getAttr(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
IValue attr(const std::string& name, IValue or_else) const {
|
|
||||||
if (auto r = module_object()->type()->findAttributeSlot(name)) {
|
|
||||||
return module_object()->getSlot(*r);
|
|
||||||
}
|
|
||||||
return or_else;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasattr(const std::string& name) const {
|
|
||||||
return module_object()->type()->findAttributeSlot(name).has_value();
|
|
||||||
}
|
|
||||||
|
|
||||||
// each module owns its method. The reference returned here
|
|
||||||
// is guarenteed to stay valid until this module has been destroyed
|
|
||||||
Method get_method(const std::string& name) const {
|
|
||||||
if (auto method = find_method(name)) {
|
|
||||||
return *method;
|
|
||||||
}
|
|
||||||
AT_ERROR("Method '", name, "' is not defined.");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply(const std::function<void(Module&)>& fn);
|
void apply(const std::function<void(Module&)>& fn);
|
||||||
|
|
@ -265,16 +170,6 @@ struct TORCH_API Module {
|
||||||
bool print_param_values,
|
bool print_param_values,
|
||||||
int level) const;
|
int level) const;
|
||||||
|
|
||||||
const std::vector<Method> get_methods() const {
|
|
||||||
return fmap(
|
|
||||||
type()->methods(),
|
|
||||||
[&](Function* func) {
|
|
||||||
return Method(module_object(), func);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
c10::optional<Method> find_method(const std::string& basename) const;
|
|
||||||
|
|
||||||
/// Enables "training" mode.
|
/// Enables "training" mode.
|
||||||
void train(bool on = true);
|
void train(bool on = true);
|
||||||
/// Calls train(false) to enable "eval" mode.
|
/// Calls train(false) to enable "eval" mode.
|
||||||
|
|
@ -311,24 +206,6 @@ struct TORCH_API Module {
|
||||||
/// effect.
|
/// effect.
|
||||||
void to(at::Device device, bool non_blocking = false);
|
void to(at::Device device, bool non_blocking = false);
|
||||||
|
|
||||||
/// Run a method from this module.
|
|
||||||
///
|
|
||||||
/// For example:
|
|
||||||
/// @code
|
|
||||||
/// IValue output = module->run("relu_script", a, b);
|
|
||||||
/// @endcode
|
|
||||||
///
|
|
||||||
/// To get a compile a module from a source string, see torch::jit::compile
|
|
||||||
///
|
|
||||||
/// @param method_name The name of the method to run
|
|
||||||
/// @param args Arguments to be passed to the method
|
|
||||||
/// @return An IValue containing the return value (or values if it is a tuple)
|
|
||||||
/// from the method
|
|
||||||
template <typename... Types>
|
|
||||||
IValue run_method(const std::string& method_name, Types&&... args) {
|
|
||||||
return get_method(method_name)({IValue(std::forward<Types>(args))...});
|
|
||||||
}
|
|
||||||
|
|
||||||
void save(
|
void save(
|
||||||
std::ostream& out,
|
std::ostream& out,
|
||||||
const ExtraFilesMap& extra_files = ExtraFilesMap()) const;
|
const ExtraFilesMap& extra_files = ExtraFilesMap()) const;
|
||||||
|
|
@ -350,18 +227,6 @@ struct TORCH_API Module {
|
||||||
|
|
||||||
void clone_method(const Module& orig, const std::string& name);
|
void clone_method(const Module& orig, const std::string& name);
|
||||||
|
|
||||||
ModulePtr module_object() const;
|
|
||||||
|
|
||||||
ClassTypePtr type() const {
|
|
||||||
return module_object()->type();
|
|
||||||
}
|
|
||||||
std::shared_ptr<CompilationUnit> class_compilation_unit() const {
|
|
||||||
return module_object()->compilation_unit();
|
|
||||||
}
|
|
||||||
|
|
||||||
// so that C++ users can easily add methods
|
|
||||||
void define(const std::string& src, const ResolverPtr& resolver = nullptr);
|
|
||||||
|
|
||||||
template <typename... Types>
|
template <typename... Types>
|
||||||
IValue create_class(const c10::QualifiedName& name, Types&&... args) const {
|
IValue create_class(const c10::QualifiedName& name, Types&&... args) const {
|
||||||
return create_class(name, {IValue(std::forward<Types>(args))...});
|
return create_class(name, {IValue(std::forward<Types>(args))...});
|
||||||
|
|
@ -369,10 +234,6 @@ struct TORCH_API Module {
|
||||||
|
|
||||||
IValue create_class(const c10::QualifiedName& name, Stack stack) const;
|
IValue create_class(const c10::QualifiedName& name, Stack stack) const;
|
||||||
|
|
||||||
size_t num_slots() const {
|
|
||||||
return module_object()->slots().size();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Module clone_impl(std::unordered_map<TypePtr, TypePtr>& type_remap) const;
|
Module clone_impl(std::unordered_map<TypePtr, TypePtr>& type_remap) const;
|
||||||
|
|
||||||
|
|
@ -382,16 +243,13 @@ struct TORCH_API Module {
|
||||||
const std::unordered_map<TypePtr, TypePtr>& type_remap);
|
const std::unordered_map<TypePtr, TypePtr>& type_remap);
|
||||||
|
|
||||||
c10::QualifiedName getNameForMethod(std::string basename) const {
|
c10::QualifiedName getNameForMethod(std::string basename) const {
|
||||||
return QualifiedName(name(), basename);
|
return QualifiedName(*type()->name(), basename);
|
||||||
}
|
}
|
||||||
|
|
||||||
void to_impl(
|
void to_impl(
|
||||||
const c10::optional<at::Device>& device,
|
const c10::optional<at::Device>& device,
|
||||||
const c10::optional<at::ScalarType>& dtype,
|
const c10::optional<at::ScalarType>& dtype,
|
||||||
bool non_blocking);
|
bool non_blocking);
|
||||||
|
|
||||||
// mutable be we lazily initialize in module_object.
|
|
||||||
mutable ModulePtr module_value_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
@ -460,8 +318,8 @@ struct slot_iterator_impl {
|
||||||
return cursors_.back();
|
return cursors_.back();
|
||||||
}
|
}
|
||||||
IValue cur() const {
|
IValue cur() const {
|
||||||
return return_module() ? top().module_.module_object()
|
return return_module() ? top().module_._ivalue()
|
||||||
: top().module_.module_object()->getSlot(top().i_);
|
: top().module_._ivalue()->getSlot(top().i_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// advance to the next slot in a depth first pre-order traversal of the
|
// advance to the next slot in a depth first pre-order traversal of the
|
||||||
|
|
@ -487,11 +345,7 @@ struct slot_iterator_impl {
|
||||||
// if the current thing is a module, we have to scan it for recursive
|
// if the current thing is a module, we have to scan it for recursive
|
||||||
// traversals. We do this by adding a new SlotCursor to track the traversal.
|
// traversals. We do this by adding a new SlotCursor to track the traversal.
|
||||||
if (recurse_ &&
|
if (recurse_ &&
|
||||||
top()
|
top().module_._ivalue()->type()->getAttribute(top().i_)->is_module()) {
|
||||||
.module_.module_object()
|
|
||||||
->type()
|
|
||||||
->getAttribute(top().i_)
|
|
||||||
->is_module()) {
|
|
||||||
cursors_.emplace_back(SlotCursor{cur().toModule(), 0});
|
cursors_.emplace_back(SlotCursor{cur().toModule(), 0});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
@ -502,7 +356,7 @@ struct slot_iterator_impl {
|
||||||
// otherwise, we have to continue advancing.
|
// otherwise, we have to continue advancing.
|
||||||
bool valid() const {
|
bool valid() const {
|
||||||
return top().i_ < int64_t(top().module_.num_slots()) &&
|
return top().i_ < int64_t(top().module_.num_slots()) &&
|
||||||
Policy::valid(top().module_.module_object()->type(), top().i_);
|
Policy::valid(top().module_._ivalue()->type(), top().i_);
|
||||||
}
|
}
|
||||||
void while_not_valid_next() {
|
void while_not_valid_next() {
|
||||||
// advance iteration until we are either at the end (cursors_.empty())
|
// advance iteration until we are either at the end (cursors_.empty())
|
||||||
|
|
|
||||||
44
torch/csrc/jit/script/object.cpp
Normal file
44
torch/csrc/jit/script/object.cpp
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
#include <torch/csrc/jit/script/object.h>
|
||||||
|
|
||||||
|
#include <ATen/core/jit_type.h>
|
||||||
|
#include <torch/csrc/jit/script/compilation_unit.h>
|
||||||
|
#include <torch/csrc/jit/script/resolver.h>
|
||||||
|
#include <torch/csrc/jit/script/sugared_value.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
namespace script {
|
||||||
|
|
||||||
|
Object::Object(
|
||||||
|
std::shared_ptr<CompilationUnit> cu,
|
||||||
|
const c10::ClassTypePtr& type)
|
||||||
|
: Object(c10::ivalue::Object::create(
|
||||||
|
c10::StrongTypePtr(std::move(cu), type),
|
||||||
|
type->numAttributes())) {}
|
||||||
|
|
||||||
|
ObjectPtr Object::_ivalue() const {
|
||||||
|
TORCH_INTERNAL_ASSERT(_ivalue_);
|
||||||
|
return _ivalue_;
|
||||||
|
}
|
||||||
|
|
||||||
|
c10::optional<Method> Object::find_method(const std::string& basename) const {
|
||||||
|
for (Function* fn : type()->methods()) {
|
||||||
|
if (fn->name() == basename) {
|
||||||
|
return Method(_ivalue(), fn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c10::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Object::define(const std::string& src, const ResolverPtr& resolver) {
|
||||||
|
const auto self = SimpleSelf(type());
|
||||||
|
_ivalue()->compilation_unit()->define(
|
||||||
|
*type()->name(),
|
||||||
|
src,
|
||||||
|
resolver ? resolver : script::nativeResolver(),
|
||||||
|
&self);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace script
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
110
torch/csrc/jit/script/object.h
Normal file
110
torch/csrc/jit/script/object.h
Normal file
|
|
@ -0,0 +1,110 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <ATen/core/ivalue.h>
|
||||||
|
#include <torch/csrc/jit/script/method.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
namespace script {
|
||||||
|
|
||||||
|
struct Resolver;
|
||||||
|
using ResolverPtr = std::shared_ptr<Resolver>;
|
||||||
|
|
||||||
|
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
|
||||||
|
|
||||||
|
struct TORCH_API Object {
|
||||||
|
Object() {}
|
||||||
|
Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {}
|
||||||
|
Object(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
|
||||||
|
Object(
|
||||||
|
c10::QualifiedName,
|
||||||
|
std::shared_ptr<CompilationUnit> cu,
|
||||||
|
bool shouldMangle = false);
|
||||||
|
|
||||||
|
ObjectPtr _ivalue() const;
|
||||||
|
|
||||||
|
c10::ClassTypePtr type() const {
|
||||||
|
return _ivalue()->type();
|
||||||
|
}
|
||||||
|
|
||||||
|
void setattr(const std::string& name, c10::IValue v) {
|
||||||
|
size_t slot = _ivalue()->type()->getAttributeSlot(name);
|
||||||
|
const c10::TypePtr& expected = _ivalue()->type()->getAttribute(slot);
|
||||||
|
TORCH_CHECK(expected, "Module has no attribute '", name, "'");
|
||||||
|
TORCH_CHECK(
|
||||||
|
v.type()->isSubtypeOf(expected),
|
||||||
|
"Expected a value of type '",
|
||||||
|
expected->python_str(),
|
||||||
|
"' for field '",
|
||||||
|
name,
|
||||||
|
"', but found '",
|
||||||
|
v.type()->python_str(),
|
||||||
|
"'");
|
||||||
|
_ivalue()->setSlot(slot, std::move(v));
|
||||||
|
}
|
||||||
|
|
||||||
|
c10::IValue attr(const std::string& name) const {
|
||||||
|
return _ivalue()->getAttr(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
c10::IValue attr(const std::string& name, c10::IValue or_else) const {
|
||||||
|
if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
|
||||||
|
return _ivalue()->getSlot(*r);
|
||||||
|
}
|
||||||
|
return or_else;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasattr(const std::string& name) const {
|
||||||
|
return _ivalue()->type()->findAttributeSlot(name).has_value();
|
||||||
|
}
|
||||||
|
|
||||||
|
// each object owns its methods. The reference returned here
|
||||||
|
// is guarenteed to stay valid until this module has been destroyed
|
||||||
|
Method get_method(const std::string& name) const {
|
||||||
|
if (auto method = find_method(name)) {
|
||||||
|
return *method;
|
||||||
|
}
|
||||||
|
AT_ERROR("Method '", name, "' is not defined.");
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<Method> get_methods() const {
|
||||||
|
return fmap(type()->methods(), [&](Function* func) {
|
||||||
|
return Method(_ivalue(), func);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
c10::optional<Method> find_method(const std::string& basename) const;
|
||||||
|
|
||||||
|
/// Run a method from this module.
|
||||||
|
///
|
||||||
|
/// For example:
|
||||||
|
/// @code
|
||||||
|
/// IValue output = module->run("relu_script", a, b);
|
||||||
|
/// @endcode
|
||||||
|
///
|
||||||
|
/// To get a compile a module from a source string, see torch::jit::compile
|
||||||
|
///
|
||||||
|
/// @param method_name The name of the method to run
|
||||||
|
/// @param args Arguments to be passed to the method
|
||||||
|
/// @return An IValue containing the return value (or values if it is a tuple)
|
||||||
|
/// from the method
|
||||||
|
template <typename... Types>
|
||||||
|
IValue run_method(const std::string& method_name, Types&&... args) {
|
||||||
|
return get_method(method_name)({IValue(std::forward<Types>(args))...});
|
||||||
|
}
|
||||||
|
|
||||||
|
// so that C++ users can easily add methods
|
||||||
|
void define(const std::string& src, const ResolverPtr& resolver = nullptr);
|
||||||
|
|
||||||
|
size_t num_slots() const {
|
||||||
|
return _ivalue()->slots().size();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// mutable be we lazily initialize in module_object.
|
||||||
|
mutable ObjectPtr _ivalue_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace script
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
|
|
@ -290,7 +290,7 @@ static void gatherParametersAndBuffers(
|
||||||
const std::string& prefix) {
|
const std::string& prefix) {
|
||||||
Graph& g = *self_value->owningGraph();
|
Graph& g = *self_value->owningGraph();
|
||||||
|
|
||||||
state->setValue(self.module_object(), self_value);
|
state->setValue(self._ivalue(), self_value);
|
||||||
|
|
||||||
auto self_ty = self.type();
|
auto self_ty = self.type();
|
||||||
for (const script::NameValue& s : self.named_attributes(/*recurse=*/false)) {
|
for (const script::NameValue& s : self.named_attributes(/*recurse=*/false)) {
|
||||||
|
|
@ -328,8 +328,8 @@ std::pair<std::shared_ptr<TracingState>, Stack> trace(
|
||||||
// if we are a module, then make sure the modules parameters are in the map
|
// if we are a module, then make sure the modules parameters are in the map
|
||||||
// and mapped to accesses to the self object
|
// and mapped to accesses to the self object
|
||||||
if (self) {
|
if (self) {
|
||||||
Value* self_value =
|
Value* self_value = state->graph->insertInput(0, "self")->setType(
|
||||||
state->graph->insertInput(0, "self")->setType(self->module_object()->type());
|
self->_ivalue()->type());
|
||||||
gatherParametersAndBuffers(state, self_value, *self, {"__module"});
|
gatherParametersAndBuffers(state, self_value, *self, {"__module"});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,7 @@ ScriptModuleOutput ScriptModuleBenchmark::runOnce(
|
||||||
function.getSchema(),
|
function.getSchema(),
|
||||||
std::move(args),
|
std::move(args),
|
||||||
std::move(kwargs),
|
std::move(kwargs),
|
||||||
model_.module_object());
|
model_._ivalue());
|
||||||
return function(std::move(stack));
|
return function(std::move(stack));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -100,7 +100,7 @@ void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
|
||||||
model_.get_method("forward").function().getSchema(),
|
model_.get_method("forward").function().getSchema(),
|
||||||
std::move(args),
|
std::move(args),
|
||||||
std::move(kwargs),
|
std::move(kwargs),
|
||||||
model_.module_object());
|
model_._ivalue());
|
||||||
inputs_.emplace_back(std::move(stack));
|
inputs_.emplace_back(std::move(stack));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ namespace detail {
|
||||||
template <class Input, class Output, class Model>
|
template <class Input, class Output, class Model>
|
||||||
class BenchmarkHelper {
|
class BenchmarkHelper {
|
||||||
public:
|
public:
|
||||||
BenchmarkHelper(): initialized_{false} {}
|
BenchmarkHelper();
|
||||||
explicit BenchmarkHelper(Model model): model_(model), initialized_(true) {}
|
explicit BenchmarkHelper(Model model): model_(model), initialized_(true) {}
|
||||||
|
|
||||||
// This method to be used in benchmark() method
|
// This method to be used in benchmark() method
|
||||||
|
|
@ -108,7 +108,14 @@ typedef BenchmarkHelper<
|
||||||
at::IValue,
|
at::IValue,
|
||||||
jit::script::Module>
|
jit::script::Module>
|
||||||
ScriptModuleBenchmark;
|
ScriptModuleBenchmark;
|
||||||
|
template <>
|
||||||
|
inline BenchmarkHelper<ScriptModuleInput, at::IValue, jit::script::Module>::BenchmarkHelper()
|
||||||
|
: model_("Module", std::make_shared<jit::script::CompilationUnit>()),
|
||||||
|
initialized_(false) {}
|
||||||
typedef BenchmarkHelper<ModuleInput, py::object, py::object> ModuleBenchmark;
|
typedef BenchmarkHelper<ModuleInput, py::object, py::object> ModuleBenchmark;
|
||||||
|
template <>
|
||||||
|
inline BenchmarkHelper<ModuleInput, py::object, py::object>::BenchmarkHelper()
|
||||||
|
: initialized_(false) {}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void ScriptModuleBenchmark::runOnce(
|
void ScriptModuleBenchmark::runOnce(
|
||||||
|
|
|
||||||
|
|
@ -1627,9 +1627,9 @@ if _enabled:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def original_name(self):
|
def original_name(self):
|
||||||
if type(self) == self._c.name:
|
if type(self) == str(self._c._type().name()):
|
||||||
return ''
|
return ''
|
||||||
return self._c.name
|
return str(self._c._type().name())
|
||||||
|
|
||||||
def define(self, src):
|
def define(self, src):
|
||||||
# We use frames_up=1 to get to the proper surrounding scope. The stack
|
# We use frames_up=1 to get to the proper surrounding scope. The stack
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user