mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
refactoring of module/object (#22203)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22203 ghimport-source-id: 6b3807ac8aa53df2fdd770b43d8e54b8f0d69c20 Test Plan: Imported from OSS Differential Revision: D15998760 Pulled By: suo fbshipit-source-id: dd51edbcb66561189ae9d94a129434092bcad01b
This commit is contained in:
parent
3b2844eeea
commit
f5919dba45
|
|
@ -129,16 +129,16 @@ void IValue::dump() const {
|
||||||
|
|
||||||
|
|
||||||
std::string ivalue::Object::name() const {
|
std::string ivalue::Object::name() const {
|
||||||
return this->type_->qualname();
|
return this->type_.type_->qualname();
|
||||||
}
|
}
|
||||||
|
|
||||||
IValue ivalue::Object::getAttr(const std::string& name) const {
|
IValue ivalue::Object::getAttr(const std::string& name) const {
|
||||||
const size_t slot = type_->getAttributeSlot(name);
|
const size_t slot = type_.type_->getAttributeSlot(name);
|
||||||
return getSlot(slot);
|
return getSlot(slot);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ivalue::Object::setAttr(const std::string& name, IValue v) {
|
void ivalue::Object::setAttr(const std::string& name, IValue v) {
|
||||||
const size_t slot = type_->getAttributeSlot(name);
|
const size_t slot = type_.type_->getAttributeSlot(name);
|
||||||
setSlot(slot, std::move(v));
|
setSlot(slot, std::move(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,12 +7,16 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
struct Function;
|
struct Function;
|
||||||
|
namespace script {
|
||||||
|
struct CompilationUnit;
|
||||||
|
}
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
template<class Key, class Value> class Dict;
|
template<class Key, class Value> class Dict;
|
||||||
template<class T> class List;
|
template<class T> class List;
|
||||||
struct IValue;
|
struct IValue;
|
||||||
|
struct ClassType;
|
||||||
namespace ivalue {
|
namespace ivalue {
|
||||||
struct Tuple;
|
struct Tuple;
|
||||||
struct Future;
|
struct Future;
|
||||||
|
|
@ -546,6 +550,20 @@ private:
|
||||||
bool is_intrusive_ptr;
|
bool is_intrusive_ptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// An owning pointer to a Class. Just a pair of shared_ptrs to the class type
|
||||||
|
// and its owning CU, so that the class type is guaranteed to stay alive as long
|
||||||
|
// as we hold this object.
|
||||||
|
struct StrongTypePtr {
|
||||||
|
StrongTypePtr(
|
||||||
|
std::shared_ptr<torch::jit::script::CompilationUnit> cu,
|
||||||
|
std::shared_ptr<ClassType> type)
|
||||||
|
: cu_(std::move(cu)), type_(type) {
|
||||||
|
TORCH_INTERNAL_ASSERT(cu_);
|
||||||
|
TORCH_INTERNAL_ASSERT(type_);
|
||||||
|
}
|
||||||
|
std::shared_ptr<torch::jit::script::CompilationUnit> cu_;
|
||||||
|
std::shared_ptr<ClassType> type_;
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
#include <ATen/core/ivalue_inl.h>
|
#include <ATen/core/ivalue_inl.h>
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,9 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
struct Function;
|
struct Function;
|
||||||
|
namespace script {
|
||||||
|
struct CompilationUnit;
|
||||||
|
}
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
|
|
@ -268,16 +271,20 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
|
||||||
// temporary way to break cyclic dependencies in modules by forcing the deletion
|
// temporary way to break cyclic dependencies in modules by forcing the deletion
|
||||||
// of functions when the module object is destructed
|
// of functions when the module object is destructed
|
||||||
typedef void (*OnDelete)(ivalue::Object*);
|
typedef void (*OnDelete)(ivalue::Object*);
|
||||||
Object(std::shared_ptr<ClassType> type, size_t numSlots, OnDelete on_delete)
|
Object(
|
||||||
|
StrongTypePtr type,
|
||||||
|
size_t numSlots,
|
||||||
|
OnDelete on_delete)
|
||||||
: type_(std::move(type)), on_delete_(on_delete) {
|
: type_(std::move(type)), on_delete_(on_delete) {
|
||||||
slots_.resize(numSlots);
|
slots_.resize(numSlots);
|
||||||
}
|
}
|
||||||
|
|
||||||
static c10::intrusive_ptr<Object> create(
|
static c10::intrusive_ptr<Object> create(
|
||||||
std::shared_ptr<ClassType> type,
|
StrongTypePtr type,
|
||||||
size_t numSlots,
|
size_t numSlots,
|
||||||
OnDelete on_delete = nullptr) {
|
OnDelete on_delete = nullptr) {
|
||||||
return c10::make_intrusive<Object>(std::move(type), numSlots, on_delete);
|
return c10::make_intrusive<Object>(
|
||||||
|
std::move(type), numSlots, on_delete);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -321,14 +328,19 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
|
||||||
return slots_;
|
return slots_;
|
||||||
}
|
}
|
||||||
std::shared_ptr<ClassType> type() const {
|
std::shared_ptr<ClassType> type() const {
|
||||||
return type_;
|
return type_.type_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<torch::jit::script::CompilationUnit> compilation_unit() {
|
||||||
|
return type_.cu_;
|
||||||
}
|
}
|
||||||
// temporarily defined in class_type.cpp to
|
// temporarily defined in class_type.cpp to
|
||||||
// ensure Modules do not leak memory
|
// ensure Modules do not leak memory
|
||||||
~Object();
|
~Object();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void resizeObject(size_t slot);
|
void resizeObject(size_t slot);
|
||||||
std::shared_ptr<ClassType> type_;
|
StrongTypePtr type_;
|
||||||
std::vector<IValue> slots_;
|
std::vector<IValue> slots_;
|
||||||
OnDelete on_delete_;
|
OnDelete on_delete_;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -431,7 +431,9 @@ inline IValue toIValue(
|
||||||
auto classType = type->expect<ClassType>();
|
auto classType = type->expect<ClassType>();
|
||||||
// 1. create a bare ivalue
|
// 1. create a bare ivalue
|
||||||
const size_t numAttrs = classType->numAttributes();
|
const size_t numAttrs = classType->numAttributes();
|
||||||
auto userObj = c10::ivalue::Object::create(classType, numAttrs);
|
auto userObj = c10::ivalue::Object::create(
|
||||||
|
c10::StrongTypePtr(classType->compilation_unit(), classType),
|
||||||
|
numAttrs);
|
||||||
|
|
||||||
// 2. copy all the contained types
|
// 2. copy all the contained types
|
||||||
for (size_t slot = 0; slot < numAttrs; slot++) {
|
for (size_t slot = 0; slot < numAttrs; slot++) {
|
||||||
|
|
@ -574,8 +576,8 @@ inline py::object toPyObject(IValue&& ivalue) {
|
||||||
return std::move(py_dict);
|
return std::move(py_dict);
|
||||||
} else if (ivalue.isObject()) {
|
} else if (ivalue.isObject()) {
|
||||||
const auto obj = std::move(ivalue).toObject();
|
const auto obj = std::move(ivalue).toObject();
|
||||||
auto& pyCu = script::CompilationUnit::_get_python_cu();
|
auto pyCu = script::CompilationUnit::_get_python_cu();
|
||||||
const auto classType = pyCu.get_class(c10::QualifiedName(obj->name()));
|
const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
|
||||||
AT_ASSERT(classType);
|
AT_ASSERT(classType);
|
||||||
auto pyClass =
|
auto pyClass =
|
||||||
py::module::import("torch.jit").attr("_get_script_class")(obj->name());
|
py::module::import("torch.jit").attr("_get_script_class")(obj->name());
|
||||||
|
|
|
||||||
|
|
@ -1074,7 +1074,8 @@ RegisterOperators reg(
|
||||||
const auto type = node->output()->type()->expect<ClassType>();
|
const auto type = node->output()->type()->expect<ClassType>();
|
||||||
const size_t numAttrs = type->numAttributes();
|
const size_t numAttrs = type->numAttributes();
|
||||||
return [type, numAttrs](Stack& stack) {
|
return [type, numAttrs](Stack& stack) {
|
||||||
auto userObj = c10::ivalue::Object::create(type, numAttrs);
|
auto userObj = c10::ivalue::Object::create(
|
||||||
|
c10::StrongTypePtr(type->compilation_unit(), type), numAttrs);
|
||||||
push(stack, std::move(userObj));
|
push(stack, std::move(userObj));
|
||||||
return 0;
|
return 0;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -162,17 +162,17 @@ struct TORCH_API CompilationUnit {
|
||||||
* Right now there is a single compilation unit that owns all ScriptClasses
|
* Right now there is a single compilation unit that owns all ScriptClasses
|
||||||
* defined in Python. Below are accessors methods for it.
|
* defined in Python. Below are accessors methods for it.
|
||||||
*/
|
*/
|
||||||
static const CompilationUnit& _get_python_cu_const() {
|
static std::shared_ptr<CompilationUnit> _get_python_cu_const() {
|
||||||
return _get_python_cu();
|
return _get_python_cu();
|
||||||
}
|
}
|
||||||
static CompilationUnit& _get_python_cu() {
|
static std::shared_ptr<CompilationUnit> _get_python_cu() {
|
||||||
static CompilationUnit pyCu;
|
static auto pyCu = std::make_shared<CompilationUnit>();
|
||||||
return pyCu;
|
return pyCu;
|
||||||
}
|
}
|
||||||
// For testing: clear all Python-defined classes to ensure that unit tests
|
// For testing: clear all Python-defined classes to ensure that unit tests
|
||||||
// have isolation.
|
// have isolation.
|
||||||
static void _clear_python_cu() {
|
static void _clear_python_cu() {
|
||||||
_get_python_cu().classes_.clear();
|
_get_python_cu()->classes_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
|
|
@ -148,11 +148,11 @@ struct PythonResolver : public Resolver {
|
||||||
annotations,
|
annotations,
|
||||||
qualifiedName,
|
qualifiedName,
|
||||||
TupleType::namedTupleSchemaFromNamesAndTypes(qualifiedName, fields, annotations));
|
TupleType::namedTupleSchemaFromNamesAndTypes(qualifiedName, fields, annotations));
|
||||||
CompilationUnit::_get_python_cu().register_class(tt);
|
CompilationUnit::_get_python_cu()->register_class(tt);
|
||||||
return tt;
|
return tt;
|
||||||
}
|
}
|
||||||
|
|
||||||
return CompilationUnit::_get_python_cu().get_class(qualifiedName);
|
return CompilationUnit::_get_python_cu()->get_class(qualifiedName);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
@ -321,8 +321,7 @@ void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
|
||||||
auto graph = func.function_->graph()->copy();
|
auto graph = func.function_->graph()->copy();
|
||||||
auto v = graph->insertInput(0, "self");
|
auto v = graph->insertInput(0, "self");
|
||||||
v->setType(module.module_object()->type());
|
v->setType(module.module_object()->type());
|
||||||
module.module_object()->type()->compilation_unit()->create_function(
|
module.module_object()->compilation_unit()->create_function("forward", graph);
|
||||||
"forward", graph);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void initJitScriptBindings(PyObject* module) {
|
void initJitScriptBindings(PyObject* module) {
|
||||||
|
|
@ -707,7 +706,7 @@ void initJitScriptBindings(PyObject* module) {
|
||||||
auto cu = std::make_shared<CompilationUnit>();
|
auto cu = std::make_shared<CompilationUnit>();
|
||||||
auto classType =
|
auto classType =
|
||||||
ClassType::create(c10::QualifiedName(qualifiedName), cu);
|
ClassType::create(c10::QualifiedName(qualifiedName), cu);
|
||||||
CompilationUnit::_get_python_cu().register_class(classType);
|
CompilationUnit::_get_python_cu()->register_class(classType);
|
||||||
std::vector<ResolverPtr> rcbs;
|
std::vector<ResolverPtr> rcbs;
|
||||||
std::vector<Def> methodDefs;
|
std::vector<Def> methodDefs;
|
||||||
for (const auto& def : classDef.defs()) {
|
for (const auto& def : classDef.defs()) {
|
||||||
|
|
@ -761,7 +760,7 @@ void initJitScriptBindings(PyObject* module) {
|
||||||
const std::vector<at::Tensor>& constant_table,
|
const std::vector<at::Tensor>& constant_table,
|
||||||
const Self& self) {
|
const Self& self) {
|
||||||
import_functions(
|
import_functions(
|
||||||
CompilationUnit::_get_python_cu_const(),
|
*CompilationUnit::_get_python_cu_const(),
|
||||||
cu,
|
cu,
|
||||||
std::make_shared<Source>(src),
|
std::make_shared<Source>(src),
|
||||||
constant_table,
|
constant_table,
|
||||||
|
|
|
||||||
|
|
@ -192,6 +192,10 @@ std::pair<std::shared_ptr<Graph>, std::vector<at::Tensor>> Method::_lowered_grap
|
||||||
return std::make_pair(result.first, loadTensors(result.second));
|
return std::make_pair(result.first, loadTensors(result.second));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void clearMethods(c10::ivalue::Object* self) {
|
||||||
|
self->compilation_unit()->drop_all_functions();
|
||||||
|
}
|
||||||
|
|
||||||
void Module::define(const std::string& src, const ResolverPtr& resolver) {
|
void Module::define(const std::string& src, const ResolverPtr& resolver) {
|
||||||
class_compilation_unit()->define(
|
class_compilation_unit()->define(
|
||||||
src,
|
src,
|
||||||
|
|
@ -292,7 +296,8 @@ IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
|
||||||
|
|
||||||
// Create a bare object with correct number of slots
|
// Create a bare object with correct number of slots
|
||||||
const size_t numAttrs = classType->numAttributes();
|
const size_t numAttrs = classType->numAttributes();
|
||||||
auto obj = c10::ivalue::Object::create(classType, numAttrs);
|
auto obj = c10::ivalue::Object::create(
|
||||||
|
c10::StrongTypePtr(class_compilation_unit(), classType), numAttrs);
|
||||||
|
|
||||||
// Invoke the `__init__()` of the class with the arguments provided.
|
// Invoke the `__init__()` of the class with the arguments provided.
|
||||||
Stack stackWithSelf = {obj};
|
Stack stackWithSelf = {obj};
|
||||||
|
|
|
||||||
|
|
@ -106,9 +106,6 @@ struct TORCH_API Method {
|
||||||
// first-class function in class_compilation_unit()
|
// first-class function in class_compilation_unit()
|
||||||
Function* function_;
|
Function* function_;
|
||||||
};
|
};
|
||||||
static void clearMethods(c10::ivalue::Object* self) {
|
|
||||||
self->type()->compilation_unit()->drop_all_functions();
|
|
||||||
}
|
|
||||||
|
|
||||||
struct TORCH_API Module {
|
struct TORCH_API Module {
|
||||||
Module(std::string class_name)
|
Module(std::string class_name)
|
||||||
|
|
@ -306,11 +303,6 @@ struct TORCH_API Module {
|
||||||
std::unordered_map<TypePtr, TypePtr>& type_remap,
|
std::unordered_map<TypePtr, TypePtr>& type_remap,
|
||||||
std::vector<std::string> names = {}) const;
|
std::vector<std::string> names = {}) const;
|
||||||
|
|
||||||
void clone_method(
|
|
||||||
const Module& orig,
|
|
||||||
const std::string& name,
|
|
||||||
const std::unordered_map<TypePtr, TypePtr>& type_remap);
|
|
||||||
|
|
||||||
void clone_method(const Module& orig, const std::string& name);
|
void clone_method(const Module& orig, const std::string& name);
|
||||||
|
|
||||||
at::optional<EntityType> kind_of(const std::string& name) const {
|
at::optional<EntityType> kind_of(const std::string& name) const {
|
||||||
|
|
@ -334,11 +326,8 @@ struct TORCH_API Module {
|
||||||
ClassTypePtr type() const {
|
ClassTypePtr type() const {
|
||||||
return module_object()->type();
|
return module_object()->type();
|
||||||
}
|
}
|
||||||
std::shared_ptr<CompilationUnit> class_compilation_unit() {
|
std::shared_ptr<CompilationUnit> class_compilation_unit() const {
|
||||||
return module_object()->type()->compilation_unit();
|
return module_object()->compilation_unit();
|
||||||
}
|
|
||||||
std::shared_ptr<const CompilationUnit> class_compilation_unit() const {
|
|
||||||
return module_object()->type()->compilation_unit();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// so that C++ users can easily add methods
|
// so that C++ users can easily add methods
|
||||||
|
|
@ -362,6 +351,10 @@ struct TORCH_API Module {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void clone_method(
|
||||||
|
const Module& orig,
|
||||||
|
const std::string& name,
|
||||||
|
const std::unordered_map<TypePtr, TypePtr>& type_remap);
|
||||||
static const char* toString(EntityType t) {
|
static const char* toString(EntityType t) {
|
||||||
switch (t) {
|
switch (t) {
|
||||||
case EntityType::MODULE:
|
case EntityType::MODULE:
|
||||||
|
|
@ -428,14 +421,15 @@ struct TORCH_API Module {
|
||||||
const c10::optional<at::ScalarType>& dtype,
|
const c10::optional<at::ScalarType>& dtype,
|
||||||
bool non_blocking);
|
bool non_blocking);
|
||||||
|
|
||||||
|
static void clearMethods(c10::ivalue::Object* self) {
|
||||||
|
self->compilation_unit()->drop_all_functions();
|
||||||
|
}
|
||||||
static ModulePtr create_module_object(std::string class_name) {
|
static ModulePtr create_module_object(std::string class_name) {
|
||||||
|
auto cu = std::make_shared<CompilationUnit>();
|
||||||
|
auto cls = ClassType::create(
|
||||||
|
QualifiedName(std::move(class_name)), cu, /*is_module=*/true);
|
||||||
return c10::ivalue::Object::create(
|
return c10::ivalue::Object::create(
|
||||||
ClassType::create(
|
c10::StrongTypePtr(std::move(cu), std::move(cls)), 0, clearMethods);
|
||||||
QualifiedName(std::move(class_name)),
|
|
||||||
std::make_shared<CompilationUnit>(),
|
|
||||||
/*is_module=*/true),
|
|
||||||
0,
|
|
||||||
clearMethods);
|
|
||||||
}
|
}
|
||||||
// mutable be we lazily initialize in module_object.
|
// mutable be we lazily initialize in module_object.
|
||||||
mutable ModulePtr module_value_;
|
mutable ModulePtr module_value_;
|
||||||
|
|
|
||||||
|
|
@ -516,8 +516,8 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
||||||
if (py::cast<bool>(isClass)) {
|
if (py::cast<bool>(isClass)) {
|
||||||
py::str qualifiedName =
|
py::str qualifiedName =
|
||||||
py::module::import("torch.jit").attr("_qualified_name")(obj);
|
py::module::import("torch.jit").attr("_qualified_name")(obj);
|
||||||
auto& pyCu = CompilationUnit::_get_python_cu();
|
auto pyCu = CompilationUnit::_get_python_cu();
|
||||||
if (auto classType = pyCu.get_class(c10::QualifiedName(qualifiedName))) {
|
if (auto classType = pyCu->get_class(c10::QualifiedName(qualifiedName))) {
|
||||||
return std::make_shared<ClassValue>(classType);
|
return std::make_shared<ClassValue>(classType);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user