[Openreg][PrivateUse1] Fix releasing tensor issue when using pin_memory (#151091)

As the title stated.

Related PR: https://github.com/pytorch/pytorch/pull/147066

Co-authored-by: Zhenbin Lin <lin-zhenbin@qq.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151091
Approved by: https://github.com/albanD
ghstack dependencies: #151005, #151007
This commit is contained in:
FFFrog 2025-04-16 09:57:56 +08:00 committed by PyTorch MergeBot
parent c7400d0026
commit e229ce34c4
3 changed files with 33 additions and 41 deletions

View File

@ -9,4 +9,33 @@ using openreg_ptr_t = uint64_t;
void set_impl_factory(PyObject* factory); void set_impl_factory(PyObject* factory);
py::function get_method(const char* name); py::function get_method(const char* name);
static constexpr char kFreeMethod[] = "free";
static constexpr char kHostFreeMethod[] = "hostFree";
template <const char* name>
static void ReportAndDelete(void* ptr) {
if (!ptr || !Py_IsInitialized()) {
return;
}
py::gil_scoped_acquire acquire;
PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
// Always stash, this will be a no-op if there is no error
PyErr_Fetch(&type, &value, &traceback);
TORCH_CHECK(
get_method(name)(reinterpret_cast<openreg_ptr_t>(ptr)).cast<bool>(),
"Failed to free memory pointer at ",
ptr);
// If that user code raised an error, just print it without raising it
if (PyErr_Occurred()) {
PyErr_Print();
}
// Restore the original error
PyErr_Restore(type, value, traceback);
}
} // namespace openreg } // namespace openreg

View File

@ -25,23 +25,11 @@ struct HostAllocator final : at::Allocator {
get_method("hostMalloc")(nbytes).cast<openreg_ptr_t>()); get_method("hostMalloc")(nbytes).cast<openreg_ptr_t>());
TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host."); TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host.");
} }
return {data, data, &ReportAndDelete, at::Device(at::kCPU)}; return {data, data, &ReportAndDelete<kHostFreeMethod>, at::Device(at::kCPU)};
}
static void ReportAndDelete(void* ptr) {
if (!ptr) {
return;
}
py::gil_scoped_acquire acquire;
TORCH_CHECK(
get_method("hostFree")(reinterpret_cast<openreg_ptr_t>(ptr))
.cast<bool>(),
"Failed to free memory pointer at ",
ptr);
} }
at::DeleterFnPtr raw_deleter() const override { at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete; return &ReportAndDelete<kHostFreeMethod>;
} }
void copy_data(void* dest, const void* src, std::size_t count) const final { void copy_data(void* dest, const void* src, std::size_t count) const final {

View File

@ -26,36 +26,11 @@ struct OpenRegAllocator final : at::Allocator {
TORCH_CHECK( TORCH_CHECK(
data, "Failed to allocator ", nbytes, " bytes on openreg device."); data, "Failed to allocator ", nbytes, " bytes on openreg device.");
} }
return {data, data, &ReportAndDelete, curr_device}; return {data, data, &ReportAndDelete<kFreeMethod>, curr_device};
}
static void ReportAndDelete(void* ptr) {
if (!ptr || !Py_IsInitialized()) {
return;
}
py::gil_scoped_acquire acquire;
PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
// Always stash, this will be a no-op if there is no error
PyErr_Fetch(&type, &value, &traceback);
TORCH_CHECK(
get_method("free")(reinterpret_cast<openreg_ptr_t>(ptr)).cast<bool>(),
"Failed to free memory pointer at ",
ptr);
// If that user code raised an error, just print it without raising it
if (PyErr_Occurred()) {
PyErr_Print();
}
// Restore the original error
PyErr_Restore(type, value, traceback);
} }
at::DeleterFnPtr raw_deleter() const override { at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete; return &ReportAndDelete<kFreeMethod>;
} }
void copy_data(void* dest, const void* src, std::size_t count) const final { void copy_data(void* dest, const void* src, std::size_t count) const final {