refactoring of module/object (#22203)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22203
ghimport-source-id: 6b3807ac8aa53df2fdd770b43d8e54b8f0d69c20

Test Plan: Imported from OSS

Differential Revision: D15998760

Pulled By: suo

fbshipit-source-id: dd51edbcb66561189ae9d94a129434092bcad01b
This commit is contained in:
Michael Suo 2019-07-04 17:07:52 -07:00 committed by Facebook Github Bot
parent 3b2844eeea
commit f5919dba45
10 changed files with 75 additions and 44 deletions

View File

@ -129,16 +129,16 @@ void IValue::dump() const {
std::string ivalue::Object::name() const { std::string ivalue::Object::name() const {
return this->type_->qualname(); return this->type_.type_->qualname();
} }
IValue ivalue::Object::getAttr(const std::string& name) const { IValue ivalue::Object::getAttr(const std::string& name) const {
const size_t slot = type_->getAttributeSlot(name); const size_t slot = type_.type_->getAttributeSlot(name);
return getSlot(slot); return getSlot(slot);
} }
void ivalue::Object::setAttr(const std::string& name, IValue v) { void ivalue::Object::setAttr(const std::string& name, IValue v) {
const size_t slot = type_->getAttributeSlot(name); const size_t slot = type_.type_->getAttributeSlot(name);
setSlot(slot, std::move(v)); setSlot(slot, std::move(v));
} }

View File

@ -7,12 +7,16 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
struct Function; struct Function;
namespace script {
struct CompilationUnit;
}
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch
namespace c10 { namespace c10 {
template<class Key, class Value> class Dict; template<class Key, class Value> class Dict;
template<class T> class List; template<class T> class List;
struct IValue; struct IValue;
struct ClassType;
namespace ivalue { namespace ivalue {
struct Tuple; struct Tuple;
struct Future; struct Future;
@ -546,6 +550,20 @@ private:
bool is_intrusive_ptr; bool is_intrusive_ptr;
}; };
// An owning pointer to a Class. Just a pair of shared_ptrs to the class type
// and its owning CU, so that the class type is guaranteed to stay alive as long
// as we hold this object.
struct StrongTypePtr {
StrongTypePtr(
std::shared_ptr<torch::jit::script::CompilationUnit> cu,
std::shared_ptr<ClassType> type)
: cu_(std::move(cu)), type_(type) {
TORCH_INTERNAL_ASSERT(cu_);
TORCH_INTERNAL_ASSERT(type_);
}
std::shared_ptr<torch::jit::script::CompilationUnit> cu_;
std::shared_ptr<ClassType> type_;
};
} }
#include <ATen/core/ivalue_inl.h> #include <ATen/core/ivalue_inl.h>

View File

@ -14,6 +14,9 @@
namespace torch { namespace torch {
namespace jit { namespace jit {
struct Function; struct Function;
namespace script {
struct CompilationUnit;
}
} // namespace jit } // namespace jit
} // namespace torch } // namespace torch
namespace c10 { namespace c10 {
@ -268,16 +271,20 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
// temporary way to break cyclic dependencies in modules by forcing the deletion // temporary way to break cyclic dependencies in modules by forcing the deletion
// of functions when the module object is destructed // of functions when the module object is destructed
typedef void (*OnDelete)(ivalue::Object*); typedef void (*OnDelete)(ivalue::Object*);
Object(std::shared_ptr<ClassType> type, size_t numSlots, OnDelete on_delete) Object(
StrongTypePtr type,
size_t numSlots,
OnDelete on_delete)
: type_(std::move(type)), on_delete_(on_delete) { : type_(std::move(type)), on_delete_(on_delete) {
slots_.resize(numSlots); slots_.resize(numSlots);
} }
static c10::intrusive_ptr<Object> create( static c10::intrusive_ptr<Object> create(
std::shared_ptr<ClassType> type, StrongTypePtr type,
size_t numSlots, size_t numSlots,
OnDelete on_delete = nullptr) { OnDelete on_delete = nullptr) {
return c10::make_intrusive<Object>(std::move(type), numSlots, on_delete); return c10::make_intrusive<Object>(
std::move(type), numSlots, on_delete);
} }
/** /**
@ -321,14 +328,19 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
return slots_; return slots_;
} }
std::shared_ptr<ClassType> type() const { std::shared_ptr<ClassType> type() const {
return type_; return type_.type_;
}
std::shared_ptr<torch::jit::script::CompilationUnit> compilation_unit() {
return type_.cu_;
} }
// temporarily defined in class_type.cpp to // temporarily defined in class_type.cpp to
// ensure Modules do not leak memory // ensure Modules do not leak memory
~Object(); ~Object();
private: private:
void resizeObject(size_t slot); void resizeObject(size_t slot);
std::shared_ptr<ClassType> type_; StrongTypePtr type_;
std::vector<IValue> slots_; std::vector<IValue> slots_;
OnDelete on_delete_; OnDelete on_delete_;
}; };

View File

@ -431,7 +431,9 @@ inline IValue toIValue(
auto classType = type->expect<ClassType>(); auto classType = type->expect<ClassType>();
// 1. create a bare ivalue // 1. create a bare ivalue
const size_t numAttrs = classType->numAttributes(); const size_t numAttrs = classType->numAttributes();
auto userObj = c10::ivalue::Object::create(classType, numAttrs); auto userObj = c10::ivalue::Object::create(
c10::StrongTypePtr(classType->compilation_unit(), classType),
numAttrs);
// 2. copy all the contained types // 2. copy all the contained types
for (size_t slot = 0; slot < numAttrs; slot++) { for (size_t slot = 0; slot < numAttrs; slot++) {
@ -574,8 +576,8 @@ inline py::object toPyObject(IValue&& ivalue) {
return std::move(py_dict); return std::move(py_dict);
} else if (ivalue.isObject()) { } else if (ivalue.isObject()) {
const auto obj = std::move(ivalue).toObject(); const auto obj = std::move(ivalue).toObject();
auto& pyCu = script::CompilationUnit::_get_python_cu(); auto pyCu = script::CompilationUnit::_get_python_cu();
const auto classType = pyCu.get_class(c10::QualifiedName(obj->name())); const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
AT_ASSERT(classType); AT_ASSERT(classType);
auto pyClass = auto pyClass =
py::module::import("torch.jit").attr("_get_script_class")(obj->name()); py::module::import("torch.jit").attr("_get_script_class")(obj->name());

View File

@ -1074,7 +1074,8 @@ RegisterOperators reg(
const auto type = node->output()->type()->expect<ClassType>(); const auto type = node->output()->type()->expect<ClassType>();
const size_t numAttrs = type->numAttributes(); const size_t numAttrs = type->numAttributes();
return [type, numAttrs](Stack& stack) { return [type, numAttrs](Stack& stack) {
auto userObj = c10::ivalue::Object::create(type, numAttrs); auto userObj = c10::ivalue::Object::create(
c10::StrongTypePtr(type->compilation_unit(), type), numAttrs);
push(stack, std::move(userObj)); push(stack, std::move(userObj));
return 0; return 0;
}; };

View File

@ -162,17 +162,17 @@ struct TORCH_API CompilationUnit {
* Right now there is a single compilation unit that owns all ScriptClasses * Right now there is a single compilation unit that owns all ScriptClasses
* defined in Python. Below are accessors methods for it. * defined in Python. Below are accessors methods for it.
*/ */
static const CompilationUnit& _get_python_cu_const() { static std::shared_ptr<CompilationUnit> _get_python_cu_const() {
return _get_python_cu(); return _get_python_cu();
} }
static CompilationUnit& _get_python_cu() { static std::shared_ptr<CompilationUnit> _get_python_cu() {
static CompilationUnit pyCu; static auto pyCu = std::make_shared<CompilationUnit>();
return pyCu; return pyCu;
} }
// For testing: clear all Python-defined classes to ensure that unit tests // For testing: clear all Python-defined classes to ensure that unit tests
// have isolation. // have isolation.
static void _clear_python_cu() { static void _clear_python_cu() {
_get_python_cu().classes_.clear(); _get_python_cu()->classes_.clear();
} }
private: private:

View File

@ -148,11 +148,11 @@ struct PythonResolver : public Resolver {
annotations, annotations,
qualifiedName, qualifiedName,
TupleType::namedTupleSchemaFromNamesAndTypes(qualifiedName, fields, annotations)); TupleType::namedTupleSchemaFromNamesAndTypes(qualifiedName, fields, annotations));
CompilationUnit::_get_python_cu().register_class(tt); CompilationUnit::_get_python_cu()->register_class(tt);
return tt; return tt;
} }
return CompilationUnit::_get_python_cu().get_class(qualifiedName); return CompilationUnit::_get_python_cu()->get_class(qualifiedName);
} }
private: private:
@ -321,8 +321,7 @@ void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
auto graph = func.function_->graph()->copy(); auto graph = func.function_->graph()->copy();
auto v = graph->insertInput(0, "self"); auto v = graph->insertInput(0, "self");
v->setType(module.module_object()->type()); v->setType(module.module_object()->type());
module.module_object()->type()->compilation_unit()->create_function( module.module_object()->compilation_unit()->create_function("forward", graph);
"forward", graph);
} }
void initJitScriptBindings(PyObject* module) { void initJitScriptBindings(PyObject* module) {
@ -707,7 +706,7 @@ void initJitScriptBindings(PyObject* module) {
auto cu = std::make_shared<CompilationUnit>(); auto cu = std::make_shared<CompilationUnit>();
auto classType = auto classType =
ClassType::create(c10::QualifiedName(qualifiedName), cu); ClassType::create(c10::QualifiedName(qualifiedName), cu);
CompilationUnit::_get_python_cu().register_class(classType); CompilationUnit::_get_python_cu()->register_class(classType);
std::vector<ResolverPtr> rcbs; std::vector<ResolverPtr> rcbs;
std::vector<Def> methodDefs; std::vector<Def> methodDefs;
for (const auto& def : classDef.defs()) { for (const auto& def : classDef.defs()) {
@ -761,7 +760,7 @@ void initJitScriptBindings(PyObject* module) {
const std::vector<at::Tensor>& constant_table, const std::vector<at::Tensor>& constant_table,
const Self& self) { const Self& self) {
import_functions( import_functions(
CompilationUnit::_get_python_cu_const(), *CompilationUnit::_get_python_cu_const(),
cu, cu,
std::make_shared<Source>(src), std::make_shared<Source>(src),
constant_table, constant_table,

View File

@ -192,6 +192,10 @@ std::pair<std::shared_ptr<Graph>, std::vector<at::Tensor>> Method::_lowered_grap
return std::make_pair(result.first, loadTensors(result.second)); return std::make_pair(result.first, loadTensors(result.second));
} }
static void clearMethods(c10::ivalue::Object* self) {
self->compilation_unit()->drop_all_functions();
}
void Module::define(const std::string& src, const ResolverPtr& resolver) { void Module::define(const std::string& src, const ResolverPtr& resolver) {
class_compilation_unit()->define( class_compilation_unit()->define(
src, src,
@ -292,7 +296,8 @@ IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
// Create a bare object with correct number of slots // Create a bare object with correct number of slots
const size_t numAttrs = classType->numAttributes(); const size_t numAttrs = classType->numAttributes();
auto obj = c10::ivalue::Object::create(classType, numAttrs); auto obj = c10::ivalue::Object::create(
c10::StrongTypePtr(class_compilation_unit(), classType), numAttrs);
// Invoke the `__init__()` of the class with the arguments provided. // Invoke the `__init__()` of the class with the arguments provided.
Stack stackWithSelf = {obj}; Stack stackWithSelf = {obj};

View File

@ -106,9 +106,6 @@ struct TORCH_API Method {
// first-class function in class_compilation_unit() // first-class function in class_compilation_unit()
Function* function_; Function* function_;
}; };
static void clearMethods(c10::ivalue::Object* self) {
self->type()->compilation_unit()->drop_all_functions();
}
struct TORCH_API Module { struct TORCH_API Module {
Module(std::string class_name) Module(std::string class_name)
@ -306,11 +303,6 @@ struct TORCH_API Module {
std::unordered_map<TypePtr, TypePtr>& type_remap, std::unordered_map<TypePtr, TypePtr>& type_remap,
std::vector<std::string> names = {}) const; std::vector<std::string> names = {}) const;
void clone_method(
const Module& orig,
const std::string& name,
const std::unordered_map<TypePtr, TypePtr>& type_remap);
void clone_method(const Module& orig, const std::string& name); void clone_method(const Module& orig, const std::string& name);
at::optional<EntityType> kind_of(const std::string& name) const { at::optional<EntityType> kind_of(const std::string& name) const {
@ -334,11 +326,8 @@ struct TORCH_API Module {
ClassTypePtr type() const { ClassTypePtr type() const {
return module_object()->type(); return module_object()->type();
} }
std::shared_ptr<CompilationUnit> class_compilation_unit() { std::shared_ptr<CompilationUnit> class_compilation_unit() const {
return module_object()->type()->compilation_unit(); return module_object()->compilation_unit();
}
std::shared_ptr<const CompilationUnit> class_compilation_unit() const {
return module_object()->type()->compilation_unit();
} }
// so that C++ users can easily add methods // so that C++ users can easily add methods
@ -362,6 +351,10 @@ struct TORCH_API Module {
} }
private: private:
void clone_method(
const Module& orig,
const std::string& name,
const std::unordered_map<TypePtr, TypePtr>& type_remap);
static const char* toString(EntityType t) { static const char* toString(EntityType t) {
switch (t) { switch (t) {
case EntityType::MODULE: case EntityType::MODULE:
@ -428,14 +421,15 @@ struct TORCH_API Module {
const c10::optional<at::ScalarType>& dtype, const c10::optional<at::ScalarType>& dtype,
bool non_blocking); bool non_blocking);
static void clearMethods(c10::ivalue::Object* self) {
self->compilation_unit()->drop_all_functions();
}
static ModulePtr create_module_object(std::string class_name) { static ModulePtr create_module_object(std::string class_name) {
auto cu = std::make_shared<CompilationUnit>();
auto cls = ClassType::create(
QualifiedName(std::move(class_name)), cu, /*is_module=*/true);
return c10::ivalue::Object::create( return c10::ivalue::Object::create(
ClassType::create( c10::StrongTypePtr(std::move(cu), std::move(cls)), 0, clearMethods);
QualifiedName(std::move(class_name)),
std::make_shared<CompilationUnit>(),
/*is_module=*/true),
0,
clearMethods);
} }
// mutable be we lazily initialize in module_object. // mutable be we lazily initialize in module_object.
mutable ModulePtr module_value_; mutable ModulePtr module_value_;

View File

@ -516,8 +516,8 @@ std::shared_ptr<SugaredValue> toSugaredValue(
if (py::cast<bool>(isClass)) { if (py::cast<bool>(isClass)) {
py::str qualifiedName = py::str qualifiedName =
py::module::import("torch.jit").attr("_qualified_name")(obj); py::module::import("torch.jit").attr("_qualified_name")(obj);
auto& pyCu = CompilationUnit::_get_python_cu(); auto pyCu = CompilationUnit::_get_python_cu();
if (auto classType = pyCu.get_class(c10::QualifiedName(qualifiedName))) { if (auto classType = pyCu->get_class(c10::QualifiedName(qualifiedName))) {
return std::make_shared<ClassValue>(classType); return std::make_shared<ClassValue>(classType);
} }
} }