mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
custom class method holder should hold a unique_ptr (#35218)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35218 We should express the ownership semantics directly here. Using `shared_ptr` makes it too easy to leak ownership by inadvertently storing a copy. Test Plan: Imported from OSS Differential Revision: D20682673 Pulled By: suo fbshipit-source-id: 32002ee515eb8bb7b37e6d0aac3c0695df4eec79
This commit is contained in:
parent
b9adbb5002
commit
cfcb63de34
|
|
@ -32,12 +32,12 @@ bool isCustomClass(const c10::IValue& v) {
|
|||
getCustomClass(v.toObject()->type()->name()->qualifiedName());
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<jit::Function>>& customClassMethods() {
|
||||
static std::vector<std::shared_ptr<jit::Function>> customClassMethods;
|
||||
std::vector<std::unique_ptr<jit::Function>>& customClassMethods() {
|
||||
static std::vector<std::unique_ptr<jit::Function>> customClassMethods;
|
||||
return customClassMethods;
|
||||
}
|
||||
|
||||
void registerCustomClassMethod(std::shared_ptr<jit::Function> fn) {
|
||||
void registerCustomClassMethod(std::unique_ptr<jit::Function> fn) {
|
||||
customClassMethods().emplace_back(std::move(fn));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -223,15 +223,15 @@ class class_ {
|
|||
typename c10::guts::infer_function_traits_t<Func>::return_type;
|
||||
detail::BoxedProxy<RetType, Func>()(stack, func);
|
||||
};
|
||||
auto method = std::make_shared<jit::BuiltinOpFunction>(
|
||||
auto method = std::make_unique<jit::BuiltinOpFunction>(
|
||||
qualMethodName, std::move(schema), std::move(wrapped_func));
|
||||
|
||||
// Register the method here to keep the Method alive.
|
||||
// ClassTypes do not hold ownership of their methods (normally it
|
||||
// those are held by the CompilationUnit), so we need a proxy for
|
||||
// that behavior here.
|
||||
registerCustomClassMethod(method);
|
||||
classTypePtr->addMethod(method.get());
|
||||
registerCustomClassMethod(std::move(method));
|
||||
}
|
||||
|
||||
std::string qualClassName;
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ inline void checkValidIdent(const std::string& str, const char *type) {
|
|||
} // namespace detail
|
||||
|
||||
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
|
||||
TORCH_API void registerCustomClassMethod(std::shared_ptr<jit::Function> method);
|
||||
TORCH_API void registerCustomClassMethod(std::unique_ptr<jit::Function> method);
|
||||
|
||||
// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
|
||||
// the ClassType pointer to the Type that describes that custom class,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user