From f5919dba458d0fc3577c857f5e06e2618729e88c Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Thu, 4 Jul 2019 17:07:52 -0700 Subject: [PATCH] refactoring of module/object (#22203) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22203 ghimport-source-id: 6b3807ac8aa53df2fdd770b43d8e54b8f0d69c20 Test Plan: Imported from OSS Differential Revision: D15998760 Pulled By: suo fbshipit-source-id: dd51edbcb66561189ae9d94a129434092bcad01b --- aten/src/ATen/core/ivalue.cpp | 6 ++-- aten/src/ATen/core/ivalue.h | 18 +++++++++++ aten/src/ATen/core/ivalue_inl.h | 22 ++++++++++--- torch/csrc/jit/pybind_utils.h | 8 +++-- torch/csrc/jit/register_prim_ops.cpp | 3 +- torch/csrc/jit/script/compilation_unit.h | 8 ++--- torch/csrc/jit/script/init.cpp | 11 +++---- torch/csrc/jit/script/module.cpp | 7 +++- torch/csrc/jit/script/module.h | 32 ++++++++----------- .../csrc/jit/script/python_sugared_value.cpp | 4 +-- 10 files changed, 75 insertions(+), 44 deletions(-) diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index b3e377c2974..83d5b333d04 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -129,16 +129,16 @@ void IValue::dump() const { std::string ivalue::Object::name() const { - return this->type_->qualname(); + return this->type_.type_->qualname(); } IValue ivalue::Object::getAttr(const std::string& name) const { - const size_t slot = type_->getAttributeSlot(name); + const size_t slot = type_.type_->getAttributeSlot(name); return getSlot(slot); } void ivalue::Object::setAttr(const std::string& name, IValue v) { - const size_t slot = type_->getAttributeSlot(name); + const size_t slot = type_.type_->getAttributeSlot(name); setSlot(slot, std::move(v)); } diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index f8e0696be3b..5ad7def0973 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -7,12 +7,16 @@ namespace torch { namespace jit { struct Function; +namespace script { +struct CompilationUnit; +} } // namespace jit } // namespace torch namespace c10 { template class Dict; template class List; struct IValue; +struct ClassType; namespace ivalue { struct Tuple; struct Future; @@ -546,6 +550,20 @@ private: bool is_intrusive_ptr; }; +// An owning pointer to a Class. Just a pair of shared_ptrs to the class type +// and its owning CU, so that the class type is guaranteed to stay alive as long +// as we hold this object. +struct StrongTypePtr { + StrongTypePtr( + std::shared_ptr cu, + std::shared_ptr type) + : cu_(std::move(cu)), type_(type) { + TORCH_INTERNAL_ASSERT(cu_); + TORCH_INTERNAL_ASSERT(type_); + } + std::shared_ptr cu_; + std::shared_ptr type_; +}; } #include diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index bafb9957d3d..ad9a2097243 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -14,6 +14,9 @@ namespace torch { namespace jit { struct Function; +namespace script { +struct CompilationUnit; +} } // namespace jit } // namespace torch namespace c10 { @@ -268,16 +271,20 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target { // temporary way to break cyclic dependencies in modules by forcing the deletion // of functions when the module object is destructed typedef void (*OnDelete)(ivalue::Object*); - Object(std::shared_ptr type, size_t numSlots, OnDelete on_delete) + Object( + StrongTypePtr type, + size_t numSlots, + OnDelete on_delete) : type_(std::move(type)), on_delete_(on_delete) { slots_.resize(numSlots); } static c10::intrusive_ptr create( - std::shared_ptr type, + StrongTypePtr type, size_t numSlots, OnDelete on_delete = nullptr) { - return c10::make_intrusive(std::move(type), numSlots, on_delete); + return c10::make_intrusive( + std::move(type), numSlots, on_delete); } /** @@ -321,14 +328,19 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target { return slots_; } std::shared_ptr type() const { - return type_; + return type_.type_; + } + + std::shared_ptr compilation_unit() { + return type_.cu_; } // temporarily defined in class_type.cpp to // ensure Modules do not leak memory ~Object(); + private: void resizeObject(size_t slot); - std::shared_ptr type_; + StrongTypePtr type_; std::vector slots_; OnDelete on_delete_; }; diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index 78bf4df53e7..0e2ba73510f 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -431,7 +431,9 @@ inline IValue toIValue( auto classType = type->expect(); // 1. create a bare ivalue const size_t numAttrs = classType->numAttributes(); - auto userObj = c10::ivalue::Object::create(classType, numAttrs); + auto userObj = c10::ivalue::Object::create( + c10::StrongTypePtr(classType->compilation_unit(), classType), + numAttrs); // 2. copy all the contained types for (size_t slot = 0; slot < numAttrs; slot++) { @@ -574,8 +576,8 @@ inline py::object toPyObject(IValue&& ivalue) { return std::move(py_dict); } else if (ivalue.isObject()) { const auto obj = std::move(ivalue).toObject(); - auto& pyCu = script::CompilationUnit::_get_python_cu(); - const auto classType = pyCu.get_class(c10::QualifiedName(obj->name())); + auto pyCu = script::CompilationUnit::_get_python_cu(); + const auto classType = pyCu->get_class(c10::QualifiedName(obj->name())); AT_ASSERT(classType); auto pyClass = py::module::import("torch.jit").attr("_get_script_class")(obj->name()); diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 6bf1850eb9d..b760d5c7294 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -1074,7 +1074,8 @@ RegisterOperators reg( const auto type = node->output()->type()->expect(); const size_t numAttrs = type->numAttributes(); return [type, numAttrs](Stack& stack) { - auto userObj = c10::ivalue::Object::create(type, numAttrs); + auto userObj = c10::ivalue::Object::create( + c10::StrongTypePtr(type->compilation_unit(), type), numAttrs); push(stack, std::move(userObj)); return 0; }; diff --git a/torch/csrc/jit/script/compilation_unit.h b/torch/csrc/jit/script/compilation_unit.h index 13c4a765f36..f76ccf90de4 100644 --- a/torch/csrc/jit/script/compilation_unit.h +++ b/torch/csrc/jit/script/compilation_unit.h @@ -162,17 +162,17 @@ struct TORCH_API CompilationUnit { * Right now there is a single compilation unit that owns all ScriptClasses * defined in Python. Below are accessors methods for it. */ - static const CompilationUnit& _get_python_cu_const() { + static std::shared_ptr _get_python_cu_const() { return _get_python_cu(); } - static CompilationUnit& _get_python_cu() { - static CompilationUnit pyCu; + static std::shared_ptr _get_python_cu() { + static auto pyCu = std::make_shared(); return pyCu; } // For testing: clear all Python-defined classes to ensure that unit tests // have isolation. static void _clear_python_cu() { - _get_python_cu().classes_.clear(); + _get_python_cu()->classes_.clear(); } private: diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 1607908ee04..41321961647 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -148,11 +148,11 @@ struct PythonResolver : public Resolver { annotations, qualifiedName, TupleType::namedTupleSchemaFromNamesAndTypes(qualifiedName, fields, annotations)); - CompilationUnit::_get_python_cu().register_class(tt); + CompilationUnit::_get_python_cu()->register_class(tt); return tt; } - return CompilationUnit::_get_python_cu().get_class(qualifiedName); + return CompilationUnit::_get_python_cu()->get_class(qualifiedName); } private: @@ -321,8 +321,7 @@ void addFunctionToModule(Module& module, const StrongFunctionPtr& func) { auto graph = func.function_->graph()->copy(); auto v = graph->insertInput(0, "self"); v->setType(module.module_object()->type()); - module.module_object()->type()->compilation_unit()->create_function( - "forward", graph); + module.module_object()->compilation_unit()->create_function("forward", graph); } void initJitScriptBindings(PyObject* module) { @@ -707,7 +706,7 @@ void initJitScriptBindings(PyObject* module) { auto cu = std::make_shared(); auto classType = ClassType::create(c10::QualifiedName(qualifiedName), cu); - CompilationUnit::_get_python_cu().register_class(classType); + CompilationUnit::_get_python_cu()->register_class(classType); std::vector rcbs; std::vector methodDefs; for (const auto& def : classDef.defs()) { @@ -761,7 +760,7 @@ void initJitScriptBindings(PyObject* module) { const std::vector& constant_table, const Self& self) { import_functions( - CompilationUnit::_get_python_cu_const(), + *CompilationUnit::_get_python_cu_const(), cu, std::make_shared(src), constant_table, diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp index 18ebcfbe45b..bfe9758c17d 100644 --- a/torch/csrc/jit/script/module.cpp +++ b/torch/csrc/jit/script/module.cpp @@ -192,6 +192,10 @@ std::pair, std::vector> Method::_lowered_grap return std::make_pair(result.first, loadTensors(result.second)); } +static void clearMethods(c10::ivalue::Object* self) { + self->compilation_unit()->drop_all_functions(); +} + void Module::define(const std::string& src, const ResolverPtr& resolver) { class_compilation_unit()->define( src, @@ -292,7 +296,8 @@ IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const { // Create a bare object with correct number of slots const size_t numAttrs = classType->numAttributes(); - auto obj = c10::ivalue::Object::create(classType, numAttrs); + auto obj = c10::ivalue::Object::create( + c10::StrongTypePtr(class_compilation_unit(), classType), numAttrs); // Invoke the `__init__()` of the class with the arguments provided. Stack stackWithSelf = {obj}; diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 38bb49c2073..ff1f1e16580 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -106,9 +106,6 @@ struct TORCH_API Method { // first-class function in class_compilation_unit() Function* function_; }; -static void clearMethods(c10::ivalue::Object* self) { - self->type()->compilation_unit()->drop_all_functions(); -} struct TORCH_API Module { Module(std::string class_name) @@ -306,11 +303,6 @@ struct TORCH_API Module { std::unordered_map& type_remap, std::vector names = {}) const; - void clone_method( - const Module& orig, - const std::string& name, - const std::unordered_map& type_remap); - void clone_method(const Module& orig, const std::string& name); at::optional kind_of(const std::string& name) const { @@ -334,11 +326,8 @@ struct TORCH_API Module { ClassTypePtr type() const { return module_object()->type(); } - std::shared_ptr class_compilation_unit() { - return module_object()->type()->compilation_unit(); - } - std::shared_ptr class_compilation_unit() const { - return module_object()->type()->compilation_unit(); + std::shared_ptr class_compilation_unit() const { + return module_object()->compilation_unit(); } // so that C++ users can easily add methods @@ -362,6 +351,10 @@ struct TORCH_API Module { } private: + void clone_method( + const Module& orig, + const std::string& name, + const std::unordered_map& type_remap); static const char* toString(EntityType t) { switch (t) { case EntityType::MODULE: @@ -428,14 +421,15 @@ struct TORCH_API Module { const c10::optional& dtype, bool non_blocking); + static void clearMethods(c10::ivalue::Object* self) { + self->compilation_unit()->drop_all_functions(); + } static ModulePtr create_module_object(std::string class_name) { + auto cu = std::make_shared(); + auto cls = ClassType::create( + QualifiedName(std::move(class_name)), cu, /*is_module=*/true); return c10::ivalue::Object::create( - ClassType::create( - QualifiedName(std::move(class_name)), - std::make_shared(), - /*is_module=*/true), - 0, - clearMethods); + c10::StrongTypePtr(std::move(cu), std::move(cls)), 0, clearMethods); } // mutable be we lazily initialize in module_object. mutable ModulePtr module_value_; diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp index 789a451b85c..9385e08059f 100644 --- a/torch/csrc/jit/script/python_sugared_value.cpp +++ b/torch/csrc/jit/script/python_sugared_value.cpp @@ -516,8 +516,8 @@ std::shared_ptr toSugaredValue( if (py::cast(isClass)) { py::str qualifiedName = py::module::import("torch.jit").attr("_qualified_name")(obj); - auto& pyCu = CompilationUnit::_get_python_cu(); - if (auto classType = pyCu.get_class(c10::QualifiedName(qualifiedName))) { + auto pyCu = CompilationUnit::_get_python_cu(); + if (auto classType = pyCu->get_class(c10::QualifiedName(qualifiedName))) { return std::make_shared(classType); } }