mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
refactoring of module/object (#22203)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22203 ghimport-source-id: 6b3807ac8aa53df2fdd770b43d8e54b8f0d69c20 Test Plan: Imported from OSS Differential Revision: D15998760 Pulled By: suo fbshipit-source-id: dd51edbcb66561189ae9d94a129434092bcad01b
This commit is contained in:
parent
3b2844eeea
commit
f5919dba45
|
|
@ -129,16 +129,16 @@ void IValue::dump() const {
|
|||
|
||||
|
||||
std::string ivalue::Object::name() const {
|
||||
return this->type_->qualname();
|
||||
return this->type_.type_->qualname();
|
||||
}
|
||||
|
||||
IValue ivalue::Object::getAttr(const std::string& name) const {
|
||||
const size_t slot = type_->getAttributeSlot(name);
|
||||
const size_t slot = type_.type_->getAttributeSlot(name);
|
||||
return getSlot(slot);
|
||||
}
|
||||
|
||||
void ivalue::Object::setAttr(const std::string& name, IValue v) {
|
||||
const size_t slot = type_->getAttributeSlot(name);
|
||||
const size_t slot = type_.type_->getAttributeSlot(name);
|
||||
setSlot(slot, std::move(v));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,12 +7,16 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
struct Function;
|
||||
namespace script {
|
||||
struct CompilationUnit;
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
namespace c10 {
|
||||
template<class Key, class Value> class Dict;
|
||||
template<class T> class List;
|
||||
struct IValue;
|
||||
struct ClassType;
|
||||
namespace ivalue {
|
||||
struct Tuple;
|
||||
struct Future;
|
||||
|
|
@ -546,6 +550,20 @@ private:
|
|||
bool is_intrusive_ptr;
|
||||
};
|
||||
|
||||
// An owning pointer to a Class. Just a pair of shared_ptrs to the class type
|
||||
// and its owning CU, so that the class type is guaranteed to stay alive as long
|
||||
// as we hold this object.
|
||||
struct StrongTypePtr {
|
||||
StrongTypePtr(
|
||||
std::shared_ptr<torch::jit::script::CompilationUnit> cu,
|
||||
std::shared_ptr<ClassType> type)
|
||||
: cu_(std::move(cu)), type_(type) {
|
||||
TORCH_INTERNAL_ASSERT(cu_);
|
||||
TORCH_INTERNAL_ASSERT(type_);
|
||||
}
|
||||
std::shared_ptr<torch::jit::script::CompilationUnit> cu_;
|
||||
std::shared_ptr<ClassType> type_;
|
||||
};
|
||||
}
|
||||
|
||||
#include <ATen/core/ivalue_inl.h>
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
struct Function;
|
||||
namespace script {
|
||||
struct CompilationUnit;
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
namespace c10 {
|
||||
|
|
@ -268,16 +271,20 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
|
|||
// temporary way to break cyclic dependencies in modules by forcing the deletion
|
||||
// of functions when the module object is destructed
|
||||
typedef void (*OnDelete)(ivalue::Object*);
|
||||
Object(std::shared_ptr<ClassType> type, size_t numSlots, OnDelete on_delete)
|
||||
Object(
|
||||
StrongTypePtr type,
|
||||
size_t numSlots,
|
||||
OnDelete on_delete)
|
||||
: type_(std::move(type)), on_delete_(on_delete) {
|
||||
slots_.resize(numSlots);
|
||||
}
|
||||
|
||||
static c10::intrusive_ptr<Object> create(
|
||||
std::shared_ptr<ClassType> type,
|
||||
StrongTypePtr type,
|
||||
size_t numSlots,
|
||||
OnDelete on_delete = nullptr) {
|
||||
return c10::make_intrusive<Object>(std::move(type), numSlots, on_delete);
|
||||
return c10::make_intrusive<Object>(
|
||||
std::move(type), numSlots, on_delete);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -321,14 +328,19 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
|
|||
return slots_;
|
||||
}
|
||||
std::shared_ptr<ClassType> type() const {
|
||||
return type_;
|
||||
return type_.type_;
|
||||
}
|
||||
|
||||
std::shared_ptr<torch::jit::script::CompilationUnit> compilation_unit() {
|
||||
return type_.cu_;
|
||||
}
|
||||
// temporarily defined in class_type.cpp to
|
||||
// ensure Modules do not leak memory
|
||||
~Object();
|
||||
|
||||
private:
|
||||
void resizeObject(size_t slot);
|
||||
std::shared_ptr<ClassType> type_;
|
||||
StrongTypePtr type_;
|
||||
std::vector<IValue> slots_;
|
||||
OnDelete on_delete_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -431,7 +431,9 @@ inline IValue toIValue(
|
|||
auto classType = type->expect<ClassType>();
|
||||
// 1. create a bare ivalue
|
||||
const size_t numAttrs = classType->numAttributes();
|
||||
auto userObj = c10::ivalue::Object::create(classType, numAttrs);
|
||||
auto userObj = c10::ivalue::Object::create(
|
||||
c10::StrongTypePtr(classType->compilation_unit(), classType),
|
||||
numAttrs);
|
||||
|
||||
// 2. copy all the contained types
|
||||
for (size_t slot = 0; slot < numAttrs; slot++) {
|
||||
|
|
@ -574,8 +576,8 @@ inline py::object toPyObject(IValue&& ivalue) {
|
|||
return std::move(py_dict);
|
||||
} else if (ivalue.isObject()) {
|
||||
const auto obj = std::move(ivalue).toObject();
|
||||
auto& pyCu = script::CompilationUnit::_get_python_cu();
|
||||
const auto classType = pyCu.get_class(c10::QualifiedName(obj->name()));
|
||||
auto pyCu = script::CompilationUnit::_get_python_cu();
|
||||
const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
|
||||
AT_ASSERT(classType);
|
||||
auto pyClass =
|
||||
py::module::import("torch.jit").attr("_get_script_class")(obj->name());
|
||||
|
|
|
|||
|
|
@ -1074,7 +1074,8 @@ RegisterOperators reg(
|
|||
const auto type = node->output()->type()->expect<ClassType>();
|
||||
const size_t numAttrs = type->numAttributes();
|
||||
return [type, numAttrs](Stack& stack) {
|
||||
auto userObj = c10::ivalue::Object::create(type, numAttrs);
|
||||
auto userObj = c10::ivalue::Object::create(
|
||||
c10::StrongTypePtr(type->compilation_unit(), type), numAttrs);
|
||||
push(stack, std::move(userObj));
|
||||
return 0;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -162,17 +162,17 @@ struct TORCH_API CompilationUnit {
|
|||
* Right now there is a single compilation unit that owns all ScriptClasses
|
||||
* defined in Python. Below are accessors methods for it.
|
||||
*/
|
||||
static const CompilationUnit& _get_python_cu_const() {
|
||||
static std::shared_ptr<CompilationUnit> _get_python_cu_const() {
|
||||
return _get_python_cu();
|
||||
}
|
||||
static CompilationUnit& _get_python_cu() {
|
||||
static CompilationUnit pyCu;
|
||||
static std::shared_ptr<CompilationUnit> _get_python_cu() {
|
||||
static auto pyCu = std::make_shared<CompilationUnit>();
|
||||
return pyCu;
|
||||
}
|
||||
// For testing: clear all Python-defined classes to ensure that unit tests
|
||||
// have isolation.
|
||||
static void _clear_python_cu() {
|
||||
_get_python_cu().classes_.clear();
|
||||
_get_python_cu()->classes_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -148,11 +148,11 @@ struct PythonResolver : public Resolver {
|
|||
annotations,
|
||||
qualifiedName,
|
||||
TupleType::namedTupleSchemaFromNamesAndTypes(qualifiedName, fields, annotations));
|
||||
CompilationUnit::_get_python_cu().register_class(tt);
|
||||
CompilationUnit::_get_python_cu()->register_class(tt);
|
||||
return tt;
|
||||
}
|
||||
|
||||
return CompilationUnit::_get_python_cu().get_class(qualifiedName);
|
||||
return CompilationUnit::_get_python_cu()->get_class(qualifiedName);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -321,8 +321,7 @@ void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
|
|||
auto graph = func.function_->graph()->copy();
|
||||
auto v = graph->insertInput(0, "self");
|
||||
v->setType(module.module_object()->type());
|
||||
module.module_object()->type()->compilation_unit()->create_function(
|
||||
"forward", graph);
|
||||
module.module_object()->compilation_unit()->create_function("forward", graph);
|
||||
}
|
||||
|
||||
void initJitScriptBindings(PyObject* module) {
|
||||
|
|
@ -707,7 +706,7 @@ void initJitScriptBindings(PyObject* module) {
|
|||
auto cu = std::make_shared<CompilationUnit>();
|
||||
auto classType =
|
||||
ClassType::create(c10::QualifiedName(qualifiedName), cu);
|
||||
CompilationUnit::_get_python_cu().register_class(classType);
|
||||
CompilationUnit::_get_python_cu()->register_class(classType);
|
||||
std::vector<ResolverPtr> rcbs;
|
||||
std::vector<Def> methodDefs;
|
||||
for (const auto& def : classDef.defs()) {
|
||||
|
|
@ -761,7 +760,7 @@ void initJitScriptBindings(PyObject* module) {
|
|||
const std::vector<at::Tensor>& constant_table,
|
||||
const Self& self) {
|
||||
import_functions(
|
||||
CompilationUnit::_get_python_cu_const(),
|
||||
*CompilationUnit::_get_python_cu_const(),
|
||||
cu,
|
||||
std::make_shared<Source>(src),
|
||||
constant_table,
|
||||
|
|
|
|||
|
|
@ -192,6 +192,10 @@ std::pair<std::shared_ptr<Graph>, std::vector<at::Tensor>> Method::_lowered_grap
|
|||
return std::make_pair(result.first, loadTensors(result.second));
|
||||
}
|
||||
|
||||
static void clearMethods(c10::ivalue::Object* self) {
|
||||
self->compilation_unit()->drop_all_functions();
|
||||
}
|
||||
|
||||
void Module::define(const std::string& src, const ResolverPtr& resolver) {
|
||||
class_compilation_unit()->define(
|
||||
src,
|
||||
|
|
@ -292,7 +296,8 @@ IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
|
|||
|
||||
// Create a bare object with correct number of slots
|
||||
const size_t numAttrs = classType->numAttributes();
|
||||
auto obj = c10::ivalue::Object::create(classType, numAttrs);
|
||||
auto obj = c10::ivalue::Object::create(
|
||||
c10::StrongTypePtr(class_compilation_unit(), classType), numAttrs);
|
||||
|
||||
// Invoke the `__init__()` of the class with the arguments provided.
|
||||
Stack stackWithSelf = {obj};
|
||||
|
|
|
|||
|
|
@ -106,9 +106,6 @@ struct TORCH_API Method {
|
|||
// first-class function in class_compilation_unit()
|
||||
Function* function_;
|
||||
};
|
||||
static void clearMethods(c10::ivalue::Object* self) {
|
||||
self->type()->compilation_unit()->drop_all_functions();
|
||||
}
|
||||
|
||||
struct TORCH_API Module {
|
||||
Module(std::string class_name)
|
||||
|
|
@ -306,11 +303,6 @@ struct TORCH_API Module {
|
|||
std::unordered_map<TypePtr, TypePtr>& type_remap,
|
||||
std::vector<std::string> names = {}) const;
|
||||
|
||||
void clone_method(
|
||||
const Module& orig,
|
||||
const std::string& name,
|
||||
const std::unordered_map<TypePtr, TypePtr>& type_remap);
|
||||
|
||||
void clone_method(const Module& orig, const std::string& name);
|
||||
|
||||
at::optional<EntityType> kind_of(const std::string& name) const {
|
||||
|
|
@ -334,11 +326,8 @@ struct TORCH_API Module {
|
|||
ClassTypePtr type() const {
|
||||
return module_object()->type();
|
||||
}
|
||||
std::shared_ptr<CompilationUnit> class_compilation_unit() {
|
||||
return module_object()->type()->compilation_unit();
|
||||
}
|
||||
std::shared_ptr<const CompilationUnit> class_compilation_unit() const {
|
||||
return module_object()->type()->compilation_unit();
|
||||
std::shared_ptr<CompilationUnit> class_compilation_unit() const {
|
||||
return module_object()->compilation_unit();
|
||||
}
|
||||
|
||||
// so that C++ users can easily add methods
|
||||
|
|
@ -362,6 +351,10 @@ struct TORCH_API Module {
|
|||
}
|
||||
|
||||
private:
|
||||
void clone_method(
|
||||
const Module& orig,
|
||||
const std::string& name,
|
||||
const std::unordered_map<TypePtr, TypePtr>& type_remap);
|
||||
static const char* toString(EntityType t) {
|
||||
switch (t) {
|
||||
case EntityType::MODULE:
|
||||
|
|
@ -428,14 +421,15 @@ struct TORCH_API Module {
|
|||
const c10::optional<at::ScalarType>& dtype,
|
||||
bool non_blocking);
|
||||
|
||||
static void clearMethods(c10::ivalue::Object* self) {
|
||||
self->compilation_unit()->drop_all_functions();
|
||||
}
|
||||
static ModulePtr create_module_object(std::string class_name) {
|
||||
auto cu = std::make_shared<CompilationUnit>();
|
||||
auto cls = ClassType::create(
|
||||
QualifiedName(std::move(class_name)), cu, /*is_module=*/true);
|
||||
return c10::ivalue::Object::create(
|
||||
ClassType::create(
|
||||
QualifiedName(std::move(class_name)),
|
||||
std::make_shared<CompilationUnit>(),
|
||||
/*is_module=*/true),
|
||||
0,
|
||||
clearMethods);
|
||||
c10::StrongTypePtr(std::move(cu), std::move(cls)), 0, clearMethods);
|
||||
}
|
||||
// mutable be we lazily initialize in module_object.
|
||||
mutable ModulePtr module_value_;
|
||||
|
|
|
|||
|
|
@ -516,8 +516,8 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
|||
if (py::cast<bool>(isClass)) {
|
||||
py::str qualifiedName =
|
||||
py::module::import("torch.jit").attr("_qualified_name")(obj);
|
||||
auto& pyCu = CompilationUnit::_get_python_cu();
|
||||
if (auto classType = pyCu.get_class(c10::QualifiedName(qualifiedName))) {
|
||||
auto pyCu = CompilationUnit::_get_python_cu();
|
||||
if (auto classType = pyCu->get_class(c10::QualifiedName(qualifiedName))) {
|
||||
return std::make_shared<ClassValue>(classType);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user