Factor Module into Object and Module
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29500

Test Plan: Imported from OSS

Differential Revision: D18463064

Pulled By: jamesr66a

fbshipit-source-id: d37bef242a8626593d4b8754042152cfc0f0acb2
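The net effect of the diff below: the generic attribute/method accessors move from script::Module into a new script::Object base class, and the wrapped ivalue::Object is reached through _ivalue() instead of module_object(). A minimal sketch of how the resulting C++ surface reads; this is illustrative only, it assumes the constructors and headers introduced in this patch, and the class name and function are made up:

#include <torch/csrc/jit/script/module.h>  // after this change, Module : public Object

using namespace torch::jit::script;

void sketch() {
  // Module(QualifiedName, CompilationUnit, shouldMangle) is one of the ctors kept on Module.
  Module m("__torch__.MyModule", std::make_shared<CompilationUnit>());

  // define(), setattr(), attr(), hasattr(), get_method() now live on Object
  // and are inherited by Module.
  m.define(R"(
    def forward(self, x):
        return x + 1
  )");

  // module_object() is gone; the wrapped ivalue::Object is exposed as _ivalue().
  c10::intrusive_ptr<c10::ivalue::Object> obj = m._ivalue();

  Method forward = m.get_method("forward");
  c10::IValue out = forward({c10::IValue(2)});  // Method prepends `self` automatically
  (void)obj;
  (void)out;
}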
This commit is contained in:
parent 14946a8891
commit 18bdf97dbb
@@ -44,7 +44,7 @@ void dump_opnames(const script::Module& m, std::unordered_set<std::string>& opnames
     }
   }
   for (const auto& sub_m : m.children()) {
-    std::cout << "sub module name: " << sub_m.name().qualifiedName() << std::endl;
+    std::cout << "sub module name: " << sub_m.type()->name()->qualifiedName() << std::endl;
     dump_opnames(sub_m, opnames);
   }
 }

@@ -444,6 +444,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
       ${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
       ${TORCH_SRC_DIR}/csrc/jit/script/logging.cpp
       ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/script/object.cpp
       ${TORCH_SRC_DIR}/csrc/jit/script/jit_exception.cpp
       ${TORCH_SRC_DIR}/csrc/jit/source_range_serialization.cpp
       ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
@@ -89,12 +89,12 @@ void testScriptObject() {
   Module m2("m2");
   std::vector<at::Tensor> constantTable;
   import_libs(
-      m1.class_compilation_unit(),
+      m1._ivalue()->compilation_unit(),
       "__torch__.FooTest",
       std::make_shared<Source>(classSrcs1),
       constantTable);
   import_libs(
-      m2.class_compilation_unit(),
+      m2._ivalue()->compilation_unit(),
       "__torch__.FooTest",
       std::make_shared<Source>(classSrcs2),
       constantTable);

@@ -64,7 +64,7 @@ void testModuleInterfaceSerialization() {
   parentMod.register_attribute(
       "subMod",
       cu->get_interface("__torch__.OneForward"),
-      subMod.module_object(),
+      subMod._ivalue(),
       /*is_parameter=*/false);
   parentMod.define(parentForward, nativeResolver());
   ASSERT_TRUE(parentMod.hasattr("subMod"));
@@ -164,6 +164,7 @@ libtorch_sources = [
     "torch/csrc/jit/hooks_for_testing.cpp",
     "torch/csrc/jit/script/builtin_functions.cpp",
     "torch/csrc/jit/script/module.cpp",
+    "torch/csrc/jit/script/object.cpp",
     "torch/csrc/jit/tracer.cpp",
     "torch/csrc/jit/fuser/kernel_cache.cpp",
     "torch/csrc/jit/fuser/compiler.cpp",

@@ -26,7 +26,7 @@ namespace serialize {
 class TORCH_API OutputArchive final {
  public:
   explicit OutputArchive(std::shared_ptr<jit::script::CompilationUnit> cu);
-  explicit OutputArchive() : cu_(std::make_shared<jit::script::CompilationUnit>()) {}
+  explicit OutputArchive() : cu_(std::make_shared<jit::script::CompilationUnit>()), module_("__torch__.Module", cu_) {}

   // Move is allowed.
   OutputArchive(OutputArchive&&) = default;

@@ -16,7 +16,7 @@
 namespace torch {
 namespace serialize {

-InputArchive::InputArchive() {}
+InputArchive::InputArchive() : module_("Module", std::make_shared<jit::script::CompilationUnit>()) {}

 void InputArchive::read(const std::string& key, c10::IValue& ivalue) {
   ivalue = module_.attr(key);
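With the default constructors above, OutputArchive and InputArchive now carry a script::Module internally rather than a bare CompilationUnit, but their public read/write surface is unchanged. A rough usage sketch under that assumption (the file name and tensor are made up for illustration):

#include <torch/torch.h>

void archive_roundtrip() {
  torch::serialize::OutputArchive out;       // default ctor now also builds a "__torch__.Module"
  out.write("weight", torch::ones({2, 2}));  // stored as an attribute on that module
  out.save_to("archive_demo.pt");            // hypothetical file name

  torch::serialize::InputArchive in;
  in.load_from("archive_demo.pt");
  torch::Tensor weight;
  in.read("weight", weight);                 // read() pulls it back out of module_.attr(...)
}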
@@ -546,7 +546,7 @@ class ScriptModuleSerializer {
     C10_LOG_API_USAGE_ONCE("torch.script.save");
     writeExtraFiles(module, extra_files);
     // Serialize the model object
-    writeArchive("data", module.module_object());
+    writeArchive("data", module._ivalue());
     // Then we werialize all code info.
     writeCode(module.type());
     // The tensor constants from the code are written to a separate archive

@@ -257,9 +257,9 @@ void ScriptModuleDeserializer::LEGACY_moduleSetState(
   // TODO: once modules are first class in the interpreter and methods are not
   // lowered, change this to `module->run_method("__setstate__", {state});`
   if (setstate->num_inputs() == 1) {
-    setstate->run({module.module_object()});
+    setstate->run({module._ivalue()});
   } else if (setstate->num_inputs() == 2) {
-    setstate->run({module.module_object(), state});
+    setstate->run({module._ivalue(), state});
   } else {
     AT_ERROR("Unexpected schema on '__setstate__'");
   }

@@ -348,11 +348,11 @@ script::Module ScriptModuleDeserializer::LEGACY_convertModule(
         LEGACY_pickled_ivalues_.at(module_def.get_state_attribute_id()));
   }

-  const ClassTypePtr& module_type = module.module_object()->type();
+  const ClassTypePtr& module_type = module._ivalue()->type();
   for (size_t i = 0, N = module_type->numAttributes(); i < N; ++i) {
     // Verify that all the non-optional attributes have been initialized
     // TODO: Issue #20497
-    const IValue& v = module.module_object()->getSlot(i);
+    const IValue& v = module._ivalue()->getSlot(i);
     if (module_type->getAttribute(i)->kind() != TypeKind::OptionalType) {
       TORCH_CHECK(
           !v.isNone(),
@@ -193,7 +193,7 @@ struct SourceImporterImpl : public Resolver,
       const script::Module& mod,
       const std::shared_ptr<Source>& src) {
     auto self = SimpleSelf(mod.type());
-    c10::QualifiedName prefix = mod.name();
+    c10::QualifiedName prefix = *mod.type()->name();
     Parser p(src);

     parsePossibleVersionNumber(p.lexer());

@@ -267,7 +267,7 @@ void initJITBindings(PyObject* module) {
       .def(
           "_jit_pass_lower_graph",
           [](std::shared_ptr<Graph>& graph, const script::Module& self) {
-            return LowerGraph(*graph, self.module_object());
+            return LowerGraph(*graph, self._ivalue());
           })
       .def("_jit_pass_loop_unrolling", UnrollLoops)
       .def(
@@ -30,7 +30,7 @@ void fillQConfigMap(
   } else {
     qconfig = parent_qconfig;
   }
-  map[module.module_object()] = qconfig;
+  map[module._ivalue()] = qconfig;

   for (const script::NameModule& s : module.named_children()) {
     std::string child_key;

@@ -39,8 +39,7 @@ void fillQConfigMap(
     } else {
      child_key = key + "." + s.name;
    }
-    fillQConfigMap(
-        s.value.module_object(), qconfig_dict, map, child_key, qconfig);
+    fillQConfigMap(s.value._ivalue(), qconfig_dict, map, child_key, qconfig);
   }
 }

@@ -184,7 +183,8 @@ Node* InsertObserversHelper::insertObserverFor(
     return nullptr;
   }

-  script::Module observer_module;
+  script::Module observer_module(
+      "Module", std::make_shared<script::CompilationUnit>());
   if (isWeightOfConvOrLinear(v)) {
     TORCH_CHECK(v->uses().size() == 1, "We only support weight being used by one node.");
     observer_module = std::get<1>(qconfig);

@@ -275,7 +275,7 @@ graph(%input, %weight, %bias, %4):
 void InsertObserversHelper::insertObservers(
     script::Module& module,
     const std::string& method_name) {
-  if (!module_qconfig_map_.count(module.module_object())) {
+  if (!module_qconfig_map_.count(module._ivalue())) {
     // the module is added by us, e.g.: observer module
     return;
   }

@@ -304,7 +304,7 @@ void InsertObserversHelper::insertObservers(
   for (size_t idx = 1; idx < method.num_inputs(); ++idx) {
     auto& v = graph->inputs()[idx];
     if (!values_to_skip_.count(v) && valueNeedsToBeQuantized(v)) {
-      auto qconfig = module_qconfig_map_.at(module.module_object());
+      auto qconfig = module_qconfig_map_.at(module._ivalue());
       if (qconfig) {
         auto observer_node =
             insertObserverFor(v, v->owningGraph(), module, qconfig.value());

@@ -339,7 +339,8 @@ void InsertObserversHelper::insertObservers(
       // the child module.
       auto module_instance = n->inputs()[0];
       auto module_method_name = n->s(attr::name);
-      script::Module callee_module;
+      script::Module callee_module(
+          "Module", std::make_shared<script::CompilationUnit>());
       if (module_instance->node()->kind() == prim::GetAttr) {
         auto child_module_name = module_instance->node()->s(attr::name);
         callee_module = module.attr(child_module_name).toModule();

@@ -365,7 +366,7 @@ void InsertObserversHelper::insertObservers(

   // Actually add observer nodes.
   for (Value* v : values_to_observe) {
-    auto qconfig = module_qconfig_map_.at(module.module_object());
+    auto qconfig = module_qconfig_map_.at(module._ivalue());
     // Skip inserting observer if no qconfig is specified
     if (qconfig) {
       insertObserverFor(v, v->owningGraph(), module, qconfig.value());

@@ -455,7 +456,7 @@ class QuantizeHelper {
     // O(N) where N is number of observer moduels with this optimization
     for (int64_t i = observer_modules_to_remove_.size() - 1; i >= 0; --i) {
       auto observer_name = observer_modules_to_remove_[i];
-      module_.module_object()->unsafeRemoveAttr(observer_name);
+      module_._ivalue()->unsafeRemoveAttr(observer_name);
       module_.type()->unsafeRemoveAttribute(observer_name);
     }
     // Destroy observer forward calls
@@ -826,7 +827,8 @@ graph(%self, %x):

     script::Method method = current.get_method("forward");
     GRAPH_DUMP(
-        current.name().name() + "::forward() before Conv2d-BatchNorm2d folding",
+        current.type()->name()->name() +
+            "::forward() before Conv2d-BatchNorm2d folding",
         method.graph());
     const auto& matches = findPatternMatches(pattern_graph, *method.graph());

@@ -458,7 +458,7 @@ inline IValue toIValue(
       auto classType = type->expect<ClassType>();
       if (auto mod = script::as_module(py::cast<py::object>(obj))) {
         // if obj is already a ScriptModule, just return its ivalue
-        return mod.value().module_object();
+        return mod.value()._ivalue();
       }
       // otherwise is a normal class object, we create a fresh
       // ivalue::Object to use from the py object.

@@ -487,7 +487,7 @@ inline IValue toIValue(
       IValue res;
       if (auto mod = script::as_module(py::cast<py::object>(obj))) {
         classType = mod.value().type();
-        res = mod.value().module_object();
+        res = mod.value()._ivalue();
       } else {
         // We inspect the value to found the compiled TorchScript class
         // and then create a ivalue::Object from that class type.

@@ -926,7 +926,7 @@ inline py::object invokeScriptMethodFromPython(
     script::Method& callee,
     tuple_slice args,
     py::kwargs kwargs) {
-  auto self = callee.owner().module_object();
+  auto self = callee.owner()._ivalue();
   return runAndInsertCall(
       callee.function(),
       args,
@@ -725,7 +725,8 @@ void initPythonIRBindings(PyObject* module_) {
   py::class_<ClassType, Type, std::shared_ptr<ClassType>>(m, "ClassType")
       .def(py::init([](const std::string& qualified_name) {
         return get_python_cu()->get_class(c10::QualifiedName(qualified_name));
-      }));
+      }))
+      .def("name", [](ClassType& self) { return self.name()->name(); });
   py::class_<InterfaceType, Type, std::shared_ptr<InterfaceType>>(
       m, "InterfaceType")
       .def(py::init([](const std::string& qualified_name) {
@@ -390,9 +390,10 @@ void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
   // Make a graph with a fake self argument
   auto graph = func.function_->graph()->copy();
   auto v = graph->insertInput(0, "self");
-  v->setType(module.module_object()->type());
-  const auto name = QualifiedName(module.name(), "forward");
-  auto method = module.class_compilation_unit()->create_function(name, graph);
+  v->setType(module._ivalue()->type());
+  const auto name = QualifiedName(*module.type()->name(), "forward");
+  auto method =
+      module._ivalue()->compilation_unit()->create_function(name, graph);
   module.type()->addMethod(method);
 }

@@ -403,7 +404,7 @@ bool ivalue_tags_match(const Module& lhs, const Module& rhs) {
     IValue b;
   };
   std::unordered_set<const void*> visited;
-  std::vector<Work> work = {{lhs.module_object(), rhs.module_object()}};
+  std::vector<Work> work = {{lhs._ivalue(), rhs._ivalue()}};
   while (!work.empty()) {
     Work item = work.back();
     work.pop_back();

@@ -501,7 +502,7 @@ struct slot_dict_impl {
   static void bind(const py::module& m, const char* name) {
     py::class_<slot_dict_impl<Policy>>(m, name)
         .def(py::init(
-            [](Module& m) { return slot_dict_impl<Policy>(m.module_object()); }))
+            [](Module& m) { return slot_dict_impl<Policy>(m._ivalue()); }))
         .def("contains", &slot_dict_impl<Policy>::contains)
        .def("items", &slot_dict_impl<Policy>::items)
        .def("setattr", &slot_dict_impl<Policy>::setattr)
@@ -562,10 +563,56 @@ void initJitScriptBindings(PyObject* module) {
   // follows.
   py::bind_map<ExtraFilesMap>(m, "ExtraFilesMap");

+  py::class_<Object>(m, "ScriptObject")
+      .def("_type", [](Module& m) { return m.type(); })
+      .def(
+          "_get_method",
+          [](Object& self, const std::string& name) -> Method {
+            return self.get_method(name);
+          },
+          py::keep_alive<0, 1>())
+      .def(
+          "setattr",
+          [](Object& self, const std::string& name, py::object value) {
+            TypePtr type = self.type()->getAttribute(name);
+            TORCH_CHECK(type, "Module has no attribute '", name, "'");
+            auto ivalue = toIValue(std::move(value), type);
+            self.setattr(name, ivalue);
+          })
+      .def(
+          "getattr",
+          [](Object& self, const std::string& name) {
+            return toPyObject(self.attr(name));
+          })
+      .def(
+          "__getattr__",
+          [](Object& self, const std::string& name) {
+            if (auto method = self.find_method(name)) {
+              return py::cast(*method);
+            }
+            return toPyObject(self.attr(name));
+          })
+      .def(
+          "hasattr",
+          [](Object& self, const std::string& name) {
+            return self.hasattr(name);
+          })
+      .def(
+          "_has_method",
+          [](Object& self, const std::string& name) {
+            return bool(self.find_method(name));
+          })
+      .def(
+          "_method_names", [](Object& self) {
+            return fmap(self.get_methods(), [](const Method& method) {
+              return method.name();
+            });
+          });
+
   // torch.jit.ScriptModule is a subclass of this C++ object.
   // Methods here are prefixed with _ since they should not be
   // public.
-  py::class_<Module>(m, "ScriptModule")
+  py::class_<Module, Object>(m, "ScriptModule")
       .def(py::init<std::string, std::shared_ptr<CompilationUnit>, bool>())
       .def(
           "save",
@@ -598,46 +645,10 @@ void initJitScriptBindings(PyObject* module) {
           py::arg("attrs") = true,
           py::arg("params") = true,
           py::arg("indent") = 0)
-      .def(
-          "_define",
-          [](Module& m,
-             std::shared_ptr<ConcreteModuleType> concreteType,
-             const std::string& script,
-             ResolutionCallback rcb) {
-            const auto self = ModuleSelf(std::move(concreteType));
-            m.class_compilation_unit()->define(
-                m.name(), script, pythonResolver(rcb), &self);
-            didFinishEmitModule(m);
-          })
-      .def("_type", [](Module& m) { return m.type(); })
-      .def(
-          "_get_method",
-          [](Module& self, const std::string& name) -> Method {
-            return self.get_method(name);
-          },
-          py::keep_alive<0, 1>())
-      .def(
-          "setattr",
-          [](Module& self, const std::string& name, py::object value) {
-            TypePtr type = self.type()->getAttribute(name);
-            TORCH_CHECK(type, "Module has no attribute '", name, "'");
-            auto ivalue = toIValue(std::move(value), type);
-            self.setattr(name, ivalue);
-          })
-      .def(
-          "getattr",
-          [](Module& self, const std::string& name) {
-            return toPyObject(self.attr(name));
-          })
-      .def(
-          "hasattr",
-          [](Module& self, const std::string& name) {
-            return self.hasattr(name);
-          })
       .def(
           "_replicate_for_data_parallel",
           [](Module& module) {
-            const ModulePtr& obj = module.module_object();
+            const ModulePtr& obj = module._ivalue();
             auto copy = c10::ivalue::Object::create(
                 c10::StrongTypePtr(obj->compilation_unit(), obj->type()),
                 obj->slots().size());

@@ -647,16 +658,24 @@ void initJitScriptBindings(PyObject* module) {
             return Module(std::move(copy));
           })
       .def(
-          "_has_method",
-          [](Module& self, const std::string& name) {
-            return bool(self.find_method(name));
+          "get_debug_state",
+          [](Module& self) {
+            if (auto m = self.find_method("forward")) {
+              return m->get_executor().getDebugState();
+            }
+            throw std::runtime_error(
+                "Attempted to call get_debug_state on a Module without a compiled forward()");
           })
       .def(
-          "_method_names",
-          [](Module& self) {
-            return fmap(self.get_methods(), [](const Method& method) {
-              return method.name();
-            });
+          "_define",
+          [](Module& m,
+             std::shared_ptr<ConcreteModuleType> concreteType,
+             const std::string& script,
+             ResolutionCallback rcb) {
+            const auto self = ModuleSelf(std::move(concreteType));
+            m._ivalue()->compilation_unit()->define(
+                *m.type()->name(), script, pythonResolver(rcb), &self);
+            didFinishEmitModule(m);
           })
       .def(
           "_create_method_from_trace",

@@ -672,21 +691,12 @@ void initJitScriptBindings(PyObject* module) {

            std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing(
                func, typed_inputs, var_lookup_fn, force_outplace, &self));
-            const auto method_name = QualifiedName(self.name(), name);
-            auto fn = self.class_compilation_unit()->create_function(
+            const auto method_name = QualifiedName(*self.type()->name(), name);
+            auto fn = self._ivalue()->compilation_unit()->create_function(
                method_name, graph);
            self.type()->addMethod(fn);
            didFinishEmitModule(self);
          })
-      .def(
-          "get_debug_state",
-          [](Module& self) {
-            if (auto m = self.find_method("forward")) {
-              return m->get_executor().getDebugState();
-            }
-            throw std::runtime_error(
-                "Attempted to call get_debug_state on a Module without a compiled forward()");
-          })
       .def_property_readonly(
           "code",
           [](Module& self) {

@@ -697,9 +707,7 @@ void initJitScriptBindings(PyObject* module) {
            return pp.str();
          })
      .def("apply", &Module::apply)
-      .def("_clone", &Module::clone)
-      .def_property_readonly(
-          "name", [](const Module& self) { return self.name().name(); });
+      .def("_clone", &Module::clone);

  slot_dict_impl<script::detail::ParameterPolicy>::bind(m, "ParameterDict");
  slot_dict_impl<script::detail::BufferPolicy>::bind(m, "BufferDict");
torch/csrc/jit/script/method.h (new file, 65 lines)

@@ -0,0 +1,65 @@
+
+#include <ATen/core/ivalue.h>
+#include <ATen/core/stack.h>
+#include <torch/csrc/jit/function.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
+
+// A method in a module, e.g. f in:
+//
+// class M(ScriptModule):
+//   @script_method
+//   def f(self, x):
+//     ...
+// Note: because Method/Module are exposed to python these
+// classes use python method naming conventions
+struct TORCH_API Method {
+  Method(ObjectPtr owner, Function* function);
+
+  // the module that contains this method.
+  Module owner() const;
+  void run(Stack& stack);
+  void run(Stack&& stack) {
+    run(stack);
+  }
+
+  c10::IValue operator()(
+      std::vector<c10::IValue> stack,
+      const Kwargs& kwargs = Kwargs());
+
+  std::shared_ptr<Graph> graph() const {
+    return function_->graph();
+  }
+
+  const std::string& name() const {
+    return function_->name();
+  }
+
+  size_t num_inputs() const {
+    return function_->num_inputs();
+  }
+
+  GraphExecutor& get_executor() {
+    return function_->get_executor();
+  }
+
+  Function& function() const {
+    return *function_;
+  }
+
+ private:
+  // Methods are uniqued onwed by a single module. This raw pointer allows
+  // looking up the module.
+  ObjectPtr owner_;
+
+  // Underlying unbound function
+  Function* function_;
+};
+
+} // namespace script
+} // namespace jit
+} // namespace torch
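The new Method type above is just an (owning ObjectPtr, Function*) pair; invoking it pushes the owner onto the stack as `self` and dispatches to the unbound Function (see the Method::run/operator() changes in module.cpp further down). A small illustrative sketch, assuming a module `m` that already has a compiled forward(self, x):

#include <torch/csrc/jit/script/module.h>

using namespace torch::jit::script;

c10::IValue call_forward(Module& m, c10::IValue x) {
  // get_method() wraps the Function registered on the module's ClassType
  // together with m._ivalue() as the owner.
  Method forward = m.get_method("forward");

  // operator() inserts the owner as the implicit `self` argument,
  // so callers only pass the explicit arguments.
  return forward({std::move(x)});
}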
@@ -14,7 +14,7 @@ namespace torch {
 namespace jit {
 namespace script {

-static ModulePtr create_module_object(
+static ObjectPtr create_module_object(
     c10::QualifiedName class_name,
     std::shared_ptr<CompilationUnit> cu,
     bool shouldMangle = false) {

@@ -33,14 +33,14 @@ static ModulePtr create_module_object(
 }

 Module::Module(c10::QualifiedName class_name)
-    : module_value_(create_module_object(
+    : Object(create_module_object(
          std::move(class_name),
          std::make_shared<CompilationUnit>())) {}

 Module::Module(
     std::shared_ptr<CompilationUnit> cu,
     const c10::ClassTypePtr& type)
-    : module_value_(c10::ivalue::Object::create(
+    : Object(c10::ivalue::Object::create(
          c10::StrongTypePtr(std::move(cu), type),
          type->numAttributes())) {}


@@ -48,21 +48,11 @@ Module::Module(
     c10::QualifiedName class_name,
     std::shared_ptr<CompilationUnit> cu,
     bool shouldMangle)
-    : module_value_(create_module_object(
+    : Object(create_module_object(
          std::move(class_name),
          std::move(cu),
          shouldMangle)) {}

-ModulePtr Module::module_object() const {
-  if (!module_value_) {
-    // User has created a Model without assigning it to something already
-    // loaded. This is done in tests, and when using the .define method.
-    module_value_ =
-        create_module_object("Module", std::make_shared<CompilationUnit>());
-  }
-  return module_value_;
-}
-
 // first class mode runs models as first class objects,
 // and does not force inlining everywhere. This is experimental
 // as we bring up the system since it will degrade performance
@@ -155,21 +145,15 @@ Module Method::owner() const {
   return Module(owner_);
 }
 void Method::run(Stack& stack) {
-  stack.insert(stack.begin(), owner().module_object());
+  stack.insert(stack.begin(), owner()._ivalue());
   function_->run(stack);
 }

 IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
-  stack.insert(stack.begin(), owner().module_object());
+  stack.insert(stack.begin(), owner()._ivalue());
   return (*function_)(std::move(stack), kwargs);
 }

-void Module::define(const std::string& src, const ResolverPtr& resolver) {
-  const auto self = SimpleSelf(type());
-  class_compilation_unit()->define(
-      name(), src, resolver ? resolver : script::nativeResolver(), &self);
-}
-
 void Module::clone_method(
     const Module& orig,
     const Function& method,

@@ -195,7 +179,7 @@ void Module::clone_method(
   auto schema = method.getSchema().cloneWithRemappedTypes(type_remap_fn);
   const auto this_method_name = getNameForMethod(method.name());
   auto copied =
-      class_compilation_unit()->create_function(this_method_name, graph);
+      _ivalue()->compilation_unit()->create_function(this_method_name, graph);
   type()->addMethod(copied);
   copied->setSchema(std::move(schema));
 }

@@ -206,8 +190,7 @@ void Module::clone_method(const Module& orig, const std::string& name) {
   while (!to_scan.empty()) {
     auto entry = to_scan.back();
     to_scan.pop_back();
-    type_remap[entry.first.module_object()->type()] =
-        entry.second.module_object()->type();
+    type_remap[entry.first._ivalue()->type()] = entry.second._ivalue()->type();
     for (const NameModule& s : entry.first.named_children()) {
       to_scan.emplace_back(
           s.value, Module(entry.second.attr(s.name).toObject()));

@@ -223,16 +206,16 @@ Module Module::clone() const {

 Module Module::clone_impl(
     std::unordered_map<TypePtr, TypePtr>& type_remap) const {
-  // Create a new module_object in the same compilation unit.
+  // Create a new _ivalue in the same compilation unit.
   // The name is the same as for the original module, but it'll be mangled.
   // The class type is also created from scratch.
-  Module r(name(), class_compilation_unit(), true);
+  Module r(*type()->name(), _ivalue()->compilation_unit(), true);
   type_remap[type()] = r.type();

   // Copy slots. If a slot is a module - recursively clone it.
   size_t N = type()->numAttributes();
   for (size_t i = 0; i < N; ++i) {
-    IValue s = module_object()->getSlot(i);
+    IValue s = _ivalue()->getSlot(i);
     if (type()->getAttribute(i)->is_module()) {
       const Module& orig = Module(s.toObject());
       Module cloned = orig.clone_impl(type_remap);

@@ -256,8 +239,8 @@ Module Module::clone_impl(

 void Module::train(bool on) {
   for (Module m : modules()) {
-    if (auto slot = m.module_object()->type()->findAttributeSlot("training")) {
-      m.module_object()->setSlot(*slot, on);
+    if (auto slot = m._ivalue()->type()->findAttributeSlot("training")) {
+      m._ivalue()->setSlot(*slot, on);
     } else {
       TORCH_INTERNAL_ASSERT("'training' attribute not found");
     }

@@ -267,7 +250,7 @@ void Module::train(bool on) {
 IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
   // Look up the class
   const auto classType =
-      class_compilation_unit()->get_class(c10::QualifiedName(name));
+      _ivalue()->compilation_unit()->get_class(c10::QualifiedName(name));
   if (!classType) {
     AT_ERROR(
         "Could not find class with name: '",

@@ -278,7 +261,7 @@ IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
   // Create a bare object with correct number of slots
   const size_t numAttrs = classType->numAttributes();
   auto obj = c10::ivalue::Object::create(
-      c10::StrongTypePtr(class_compilation_unit(), classType), numAttrs);
+      c10::StrongTypePtr(_ivalue()->compilation_unit(), classType), numAttrs);

   // Invoke the `__init__()` of the class with the arguments provided.
   Stack stackWithSelf = {obj};

@@ -319,15 +302,6 @@ named_parameter_list Module::named_parameters(bool recurse) const {
   return named_parameter_list(*this, recurse, /*return_module=*/false);
 }

-c10::optional<Method> Module::find_method(const std::string& basename) const {
-  for (Function* fn : type()->methods()) {
-    if (fn->name() == basename) {
-      return Method(module_object(), fn);
-    }
-  }
-  return c10::nullopt;
-}
-
 attribute_list Module::attributes(bool recurse) const {
   return attribute_list(*this, recurse, /*return_module=*/false);
 }
@@ -380,7 +354,7 @@ std::string Module::dump_to_str(
     methods_ss << " }" << std::endl;
   }

-  ss << "module " << name().qualifiedName() << " {" << std::endl;
+  ss << "module " << type()->name()->qualifiedName() << " {" << std::endl;
   ss << " parameters {" << std::endl;
   ss << torch::jit::jit_log_prefix(" ", parameters_ss.str());
   ss << " }" << std::endl;
@@ -6,6 +6,7 @@
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/named_value.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
+#include <torch/csrc/jit/script/object.h>
 #include <torch/csrc/jit/source_range.h>

 #include <torch/csrc/WindowsTorchApiMacro.h>

@@ -84,73 +85,17 @@ using named_buffer_list =

 using ModuleLookup = std::function<Module(const std::vector<std::string>&)>;

-// A method in a module, e.g. f in:
-//
-// class M(ScriptModule):
-//   @script_method
-//   def f(self, x):
-//     ...
-// Note: because Method/Module are exposed to python these
-// classes use python method naming conventions
-struct TORCH_API Method {
-  Method(ModulePtr owner, Function* function);
-
-  // the module that contains this method.
-  Module owner() const;
-  void run(Stack& stack);
-  void run(Stack&& stack) {
-    run(stack);
-  }
-
-  IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs());
-
-  std::shared_ptr<Graph> graph() const {
-    return function_->graph();
-  }
-
-  const std::string& name() const {
-    return function_->name();
-  }
-
-  size_t num_inputs() const {
-    return function_->num_inputs();
-  }
-
-  GraphExecutor& get_executor() {
-    return function_->get_executor();
-  }
-
-  Function& function() const {
-    return *function_;
-  }
-
- private:
-  // Methods are uniqued onwed by a single module. This raw pointer allows
-  // looking up the module.
-  ModulePtr owner_;
-
-  // Underlying unbound function
-  // This is the _lowered_ function and is different than the
-  // first-class function in class_compilation_unit()
-  Function* function_;
-};
-
-struct TORCH_API Module {
+struct TORCH_API Module : public Object {
   explicit Module(c10::QualifiedName class_name);
   Module(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
+  Module() {}
   Module(
       c10::QualifiedName,
       std::shared_ptr<CompilationUnit> cu,
       bool shouldMangle = false);
-  // module_value_ null and will be lazily initialized if is needed
-  Module() {}
-  Module(ModulePtr module_value) : module_value_(std::move(module_value)) {}
+  Module(ModulePtr module_value) : Object(std::move(module_value)) {}
   ~Module() {}

-  const c10::QualifiedName& name() const {
-    return *module_object()->type()->name();
-  }
-
   void set_optimized(bool o) {
     AT_WARN(
         "Module::set_optimized() is deprecated and has no effect. "
@@ -174,14 +119,14 @@ struct TORCH_API Module {
   // whether a slot is a parameter to be able to classify it.
   void register_buffer(const std::string& name, at::Tensor v) {
     type()->addOrCheckAttribute(name, TensorType::get());
-    module_object()->setAttr(name, std::move(v));
+    _ivalue()->setAttr(name, std::move(v));
   }
   void register_parameter(
       const std::string& name,
       at::Tensor v,
       bool is_buffer) {
     type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer);
-    module_object()->setAttr(name, std::move(v));
+    _ivalue()->setAttr(name, std::move(v));
   }
   void register_attribute(
       const std::string& name,

@@ -189,51 +134,11 @@ struct TORCH_API Module {
       IValue v,
       bool is_param = false) {
     type()->addOrCheckAttribute(name, t, is_param);
-    module_object()->setAttr(name, std::move(v));
+    _ivalue()->setAttr(name, std::move(v));
   }
   void register_module(const std::string& name, const Module& module) {
     type()->addOrCheckAttribute(name, module.type());
-    module_object()->setAttr(name, module.module_object());
-  }
-
-  void setattr(const std::string& name, IValue v) {
-    size_t slot = module_object()->type()->getAttributeSlot(name);
-    const TypePtr& expected = module_object()->type()->getAttribute(slot);
-    TORCH_CHECK(expected, "Module has no attribute '", name, "'");
-    TORCH_CHECK(
-        v.type()->isSubtypeOf(expected),
-        "Expected a value of type '",
-        expected->python_str(),
-        "' for field '",
-        name,
-        "', but found '",
-        v.type()->python_str(),
-        "'");
-    module_object()->setSlot(slot, std::move(v));
-  }
-
-  IValue attr(const std::string& name) const {
-    return module_object()->getAttr(name);
-  }
-
-  IValue attr(const std::string& name, IValue or_else) const {
-    if (auto r = module_object()->type()->findAttributeSlot(name)) {
-      return module_object()->getSlot(*r);
-    }
-    return or_else;
-  }
-
-  bool hasattr(const std::string& name) const {
-    return module_object()->type()->findAttributeSlot(name).has_value();
-  }
-
-  // each module owns its method. The reference returned here
-  // is guarenteed to stay valid until this module has been destroyed
-  Method get_method(const std::string& name) const {
-    if (auto method = find_method(name)) {
-      return *method;
-    }
-    AT_ERROR("Method '", name, "' is not defined.");
+    _ivalue()->setAttr(name, module._ivalue());
   }

   void apply(const std::function<void(Module&)>& fn);

@@ -265,16 +170,6 @@ struct TORCH_API Module {
       bool print_param_values,
       int level) const;

-  const std::vector<Method> get_methods() const {
-    return fmap(
-        type()->methods(),
-        [&](Function* func) {
-          return Method(module_object(), func);
-        });
-  }
-
-  c10::optional<Method> find_method(const std::string& basename) const;
-
   /// Enables "training" mode.
   void train(bool on = true);
   /// Calls train(false) to enable "eval" mode.

@@ -311,24 +206,6 @@ struct TORCH_API Module {
   /// effect.
   void to(at::Device device, bool non_blocking = false);

-  /// Run a method from this module.
-  ///
-  /// For example:
-  /// @code
-  ///   IValue output = module->run("relu_script", a, b);
-  /// @endcode
-  ///
-  /// To get a compile a module from a source string, see torch::jit::compile
-  ///
-  /// @param method_name The name of the method to run
-  /// @param args Arguments to be passed to the method
-  /// @return An IValue containing the return value (or values if it is a tuple)
-  /// from the method
-  template <typename... Types>
-  IValue run_method(const std::string& method_name, Types&&... args) {
-    return get_method(method_name)({IValue(std::forward<Types>(args))...});
-  }
-
   void save(
       std::ostream& out,
       const ExtraFilesMap& extra_files = ExtraFilesMap()) const;

@@ -350,18 +227,6 @@ struct TORCH_API Module {

   void clone_method(const Module& orig, const std::string& name);

-  ModulePtr module_object() const;
-
-  ClassTypePtr type() const {
-    return module_object()->type();
-  }
-  std::shared_ptr<CompilationUnit> class_compilation_unit() const {
-    return module_object()->compilation_unit();
-  }
-
-  // so that C++ users can easily add methods
-  void define(const std::string& src, const ResolverPtr& resolver = nullptr);
-
   template <typename... Types>
   IValue create_class(const c10::QualifiedName& name, Types&&... args) const {
     return create_class(name, {IValue(std::forward<Types>(args))...});

@@ -369,10 +234,6 @@ struct TORCH_API Module {

   IValue create_class(const c10::QualifiedName& name, Stack stack) const;

-  size_t num_slots() const {
-    return module_object()->slots().size();
-  }
-
  private:
   Module clone_impl(std::unordered_map<TypePtr, TypePtr>& type_remap) const;

@@ -382,16 +243,13 @@ struct TORCH_API Module {
       const std::unordered_map<TypePtr, TypePtr>& type_remap);

   c10::QualifiedName getNameForMethod(std::string basename) const {
-    return QualifiedName(name(), basename);
+    return QualifiedName(*type()->name(), basename);
   }

   void to_impl(
       const c10::optional<at::Device>& device,
       const c10::optional<at::ScalarType>& dtype,
       bool non_blocking);
-
-  // mutable be we lazily initialize in module_object.
-  mutable ModulePtr module_value_;
 };

 namespace detail {

@@ -460,8 +318,8 @@ struct slot_iterator_impl {
     return cursors_.back();
   }
   IValue cur() const {
-    return return_module() ? top().module_.module_object()
-                           : top().module_.module_object()->getSlot(top().i_);
+    return return_module() ? top().module_._ivalue()
+                           : top().module_._ivalue()->getSlot(top().i_);
   }

   // advance to the next slot in a depth first pre-order traversal of the

@@ -487,11 +345,7 @@ struct slot_iterator_impl {
     // if the current thing is a module, we have to scan it for recursive
     // traversals. We do this by adding a new SlotCursor to track the traversal.
     if (recurse_ &&
-        top()
-            .module_.module_object()
-            ->type()
-            ->getAttribute(top().i_)
-            ->is_module()) {
+        top().module_._ivalue()->type()->getAttribute(top().i_)->is_module()) {
       cursors_.emplace_back(SlotCursor{cur().toModule(), 0});
       return;
     }

@@ -502,7 +356,7 @@ struct slot_iterator_impl {
   // otherwise, we have to continue advancing.
   bool valid() const {
     return top().i_ < int64_t(top().module_.num_slots()) &&
-        Policy::valid(top().module_.module_object()->type(), top().i_);
+        Policy::valid(top().module_._ivalue()->type(), top().i_);
   }
   void while_not_valid_next() {
     // advance iteration until we are either at the end (cursors_.empty())
torch/csrc/jit/script/object.cpp (new file, 44 lines)

@@ -0,0 +1,44 @@
+#include <torch/csrc/jit/script/object.h>
+
+#include <ATen/core/jit_type.h>
+#include <torch/csrc/jit/script/compilation_unit.h>
+#include <torch/csrc/jit/script/resolver.h>
+#include <torch/csrc/jit/script/sugared_value.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+Object::Object(
+    std::shared_ptr<CompilationUnit> cu,
+    const c10::ClassTypePtr& type)
+    : Object(c10::ivalue::Object::create(
+          c10::StrongTypePtr(std::move(cu), type),
+          type->numAttributes())) {}
+
+ObjectPtr Object::_ivalue() const {
+  TORCH_INTERNAL_ASSERT(_ivalue_);
+  return _ivalue_;
+}
+
+c10::optional<Method> Object::find_method(const std::string& basename) const {
+  for (Function* fn : type()->methods()) {
+    if (fn->name() == basename) {
+      return Method(_ivalue(), fn);
+    }
+  }
+  return c10::nullopt;
+}
+
+void Object::define(const std::string& src, const ResolverPtr& resolver) {
+  const auto self = SimpleSelf(type());
+  _ivalue()->compilation_unit()->define(
+      *type()->name(),
+      src,
+      resolver ? resolver : script::nativeResolver(),
+      &self);
+}
+
+} // namespace script
+} // namespace jit
+} // namespace torch
torch/csrc/jit/script/object.h (new file, 110 lines)

@@ -0,0 +1,110 @@
+#pragma once
+
+#include <ATen/core/ivalue.h>
+#include <torch/csrc/jit/script/method.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+struct Resolver;
+using ResolverPtr = std::shared_ptr<Resolver>;
+
+using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
+
+struct TORCH_API Object {
+  Object() {}
+  Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {}
+  Object(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
+  Object(
+      c10::QualifiedName,
+      std::shared_ptr<CompilationUnit> cu,
+      bool shouldMangle = false);
+
+  ObjectPtr _ivalue() const;
+
+  c10::ClassTypePtr type() const {
+    return _ivalue()->type();
+  }
+
+  void setattr(const std::string& name, c10::IValue v) {
+    size_t slot = _ivalue()->type()->getAttributeSlot(name);
+    const c10::TypePtr& expected = _ivalue()->type()->getAttribute(slot);
+    TORCH_CHECK(expected, "Module has no attribute '", name, "'");
+    TORCH_CHECK(
+        v.type()->isSubtypeOf(expected),
+        "Expected a value of type '",
+        expected->python_str(),
+        "' for field '",
+        name,
+        "', but found '",
+        v.type()->python_str(),
+        "'");
+    _ivalue()->setSlot(slot, std::move(v));
+  }
+
+  c10::IValue attr(const std::string& name) const {
+    return _ivalue()->getAttr(name);
+  }
+
+  c10::IValue attr(const std::string& name, c10::IValue or_else) const {
+    if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
+      return _ivalue()->getSlot(*r);
+    }
+    return or_else;
+  }
+
+  bool hasattr(const std::string& name) const {
+    return _ivalue()->type()->findAttributeSlot(name).has_value();
+  }
+
+  // each object owns its methods. The reference returned here
+  // is guarenteed to stay valid until this module has been destroyed
+  Method get_method(const std::string& name) const {
+    if (auto method = find_method(name)) {
+      return *method;
+    }
+    AT_ERROR("Method '", name, "' is not defined.");
+  }
+
+  const std::vector<Method> get_methods() const {
+    return fmap(type()->methods(), [&](Function* func) {
+      return Method(_ivalue(), func);
+    });
+  }
+
+  c10::optional<Method> find_method(const std::string& basename) const;
+
+  /// Run a method from this module.
+  ///
+  /// For example:
+  /// @code
+  ///   IValue output = module->run("relu_script", a, b);
+  /// @endcode
+  ///
+  /// To get a compile a module from a source string, see torch::jit::compile
+  ///
+  /// @param method_name The name of the method to run
+  /// @param args Arguments to be passed to the method
+  /// @return An IValue containing the return value (or values if it is a tuple)
+  /// from the method
+  template <typename... Types>
+  IValue run_method(const std::string& method_name, Types&&... args) {
+    return get_method(method_name)({IValue(std::forward<Types>(args))...});
+  }
+
+  // so that C++ users can easily add methods
+  void define(const std::string& src, const ResolverPtr& resolver = nullptr);
+
+  size_t num_slots() const {
+    return _ivalue()->slots().size();
+  }
+
+ private:
+  // mutable be we lazily initialize in module_object.
+  mutable ObjectPtr _ivalue_;
+};
+
+} // namespace script
+} // namespace jit
+} // namespace torch
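Everything generic about a scripted object (attribute slots, method lookup, define) now lives on Object, and Module only layers module-specific behaviour (parameters, buffers, train/eval, save) on top. A hedged sketch of the Object surface declared above; the names `poke`, `counter`, and `bump` are invented for illustration, and the object is assumed to have an int attribute `counter` and a method `bump(self)`:

#include <torch/csrc/jit/script/object.h>

using namespace torch::jit::script;

void poke(Object obj) {
  if (obj.hasattr("counter")) {
    c10::IValue current = obj.attr("counter");                  // read the attribute slot
    obj.setattr("counter", c10::IValue(current.toInt() + 1));   // type-checked write into the same slot
  }

  // run_method() resolves the method via get_method() and passes `self` implicitly.
  c10::IValue result = obj.run_method("bump");
  (void)result;
}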
@@ -290,7 +290,7 @@ static void gatherParametersAndBuffers(
     const std::string& prefix) {
   Graph& g = *self_value->owningGraph();

-  state->setValue(self.module_object(), self_value);
+  state->setValue(self._ivalue(), self_value);

   auto self_ty = self.type();
   for (const script::NameValue& s : self.named_attributes(/*recurse=*/false)) {

@@ -328,8 +328,8 @@ std::pair<std::shared_ptr<TracingState>, Stack> trace(
   // if we are a module, then make sure the modules parameters are in the map
   // and mapped to accesses to the self object
   if (self) {
-    Value* self_value =
-        state->graph->insertInput(0, "self")->setType(self->module_object()->type());
+    Value* self_value = state->graph->insertInput(0, "self")->setType(
+        self->_ivalue()->type());
     gatherParametersAndBuffers(state, self_value, *self, {"__module"});
   }

@@ -75,7 +75,7 @@ ScriptModuleOutput ScriptModuleBenchmark::runOnce(
       function.getSchema(),
       std::move(args),
       std::move(kwargs),
-      model_.module_object());
+      model_._ivalue());
   return function(std::move(stack));
 }

@@ -100,7 +100,7 @@ void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
       model_.get_method("forward").function().getSchema(),
       std::move(args),
       std::move(kwargs),
-      model_.module_object());
+      model_._ivalue());
   inputs_.emplace_back(std::move(stack));
 }

@@ -59,7 +59,7 @@ namespace detail {
 template <class Input, class Output, class Model>
 class BenchmarkHelper {
  public:
-  BenchmarkHelper(): initialized_{false} {}
+  BenchmarkHelper();
   explicit BenchmarkHelper(Model model): model_(model), initialized_(true) {}

   // This method to be used in benchmark() method

@@ -108,7 +108,14 @@ typedef BenchmarkHelper<
     at::IValue,
     jit::script::Module>
     ScriptModuleBenchmark;
+template <>
+inline BenchmarkHelper<ScriptModuleInput, at::IValue, jit::script::Module>::BenchmarkHelper()
+    : model_("Module", std::make_shared<jit::script::CompilationUnit>()),
+      initialized_(false) {}
 typedef BenchmarkHelper<ModuleInput, py::object, py::object> ModuleBenchmark;
+template <>
+inline BenchmarkHelper<ModuleInput, py::object, py::object>::BenchmarkHelper()
+    : initialized_(false) {}

 template <>
 void ScriptModuleBenchmark::runOnce(
@@ -1627,9 +1627,9 @@ if _enabled:

         @property
         def original_name(self):
-            if type(self) == self._c.name:
+            if type(self) == str(self._c._type().name()):
                 return ''
-            return self._c.name
+            return str(self._c._type().name())

         def define(self, src):
             # We use frames_up=1 to get to the proper surrounding scope. The stack