custom class method holder should hold a unique_ptr (#35218)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35218

We should express the ownership semantics directly here. Using
`shared_ptr` makes it too easy to leak ownership by inadvertently
storing a copy.

Test Plan: Imported from OSS

Differential Revision: D20682673

Pulled By: suo

fbshipit-source-id: 32002ee515eb8bb7b37e6d0aac3c0695df4eec79
This commit is contained in:
Michael Suo 2020-03-27 16:49:44 -07:00 committed by Facebook GitHub Bot
parent b9adbb5002
commit cfcb63de34
3 changed files with 6 additions and 6 deletions

View File

@ -32,12 +32,12 @@ bool isCustomClass(const c10::IValue& v) {
getCustomClass(v.toObject()->type()->name()->qualifiedName());
}
std::vector<std::shared_ptr<jit::Function>>& customClassMethods() {
static std::vector<std::shared_ptr<jit::Function>> customClassMethods;
std::vector<std::unique_ptr<jit::Function>>& customClassMethods() {
static std::vector<std::unique_ptr<jit::Function>> customClassMethods;
return customClassMethods;
}
void registerCustomClassMethod(std::shared_ptr<jit::Function> fn) {
void registerCustomClassMethod(std::unique_ptr<jit::Function> fn) {
customClassMethods().emplace_back(std::move(fn));
}

View File

@ -223,15 +223,15 @@ class class_ {
typename c10::guts::infer_function_traits_t<Func>::return_type;
detail::BoxedProxy<RetType, Func>()(stack, func);
};
auto method = std::make_shared<jit::BuiltinOpFunction>(
auto method = std::make_unique<jit::BuiltinOpFunction>(
qualMethodName, std::move(schema), std::move(wrapped_func));
// Register the method here to keep the Method alive.
// ClassTypes do not hold ownership of their methods (normally it
// those are held by the CompilationUnit), so we need a proxy for
// that behavior here.
registerCustomClassMethod(method);
classTypePtr->addMethod(method.get());
registerCustomClassMethod(std::move(method));
}
std::string qualClassName;

View File

@ -136,7 +136,7 @@ inline void checkValidIdent(const std::string& str, const char *type) {
} // namespace detail
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
TORCH_API void registerCustomClassMethod(std::shared_ptr<jit::Function> method);
TORCH_API void registerCustomClassMethod(std::unique_ptr<jit::Function> method);
// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
// the ClassType pointer to the Type that describes that custom class,