mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Openreg][PrivateUse1] Fix releasing tensor issue when using pin_memory (#151091)
As the title states. Related PR: https://github.com/pytorch/pytorch/pull/147066 Co-authored-by: Zhenbin Lin <lin-zhenbin@qq.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/151091 Approved by: https://github.com/albanD ghstack dependencies: #151005, #151007
This commit is contained in:
parent
c7400d0026
commit
e229ce34c4
|
|
@ -9,4 +9,33 @@ using openreg_ptr_t = uint64_t;
|
|||
// Takes a Python object named `factory`.
// NOTE(review): the definition is not visible here — presumably this is the
// object that get_method() later resolves method names against; confirm.
void set_impl_factory(PyObject* factory);
|
||||
// Returns a py::function for `name`. How the name is resolved is defined
// elsewhere; ReportAndDelete invokes the result with the pointer value and
// casts the outcome to bool.
py::function get_method(const char* name);
|
||||
|
||||
// Method name used as the template argument ReportAndDelete<kFreeMethod>.
static constexpr char kFreeMethod[] = "free";
|
||||
// Method name used as the template argument ReportAndDelete<kHostFreeMethod>.
static constexpr char kHostFreeMethod[] = "hostFree";
|
||||
|
||||
template <const char* name>
|
||||
static void ReportAndDelete(void* ptr) {
|
||||
if (!ptr || !Py_IsInitialized()) {
|
||||
return;
|
||||
}
|
||||
|
||||
py::gil_scoped_acquire acquire;
|
||||
|
||||
PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
|
||||
// Always stash, this will be a no-op if there is no error
|
||||
PyErr_Fetch(&type, &value, &traceback);
|
||||
|
||||
TORCH_CHECK(
|
||||
get_method(name)(reinterpret_cast<openreg_ptr_t>(ptr)).cast<bool>(),
|
||||
"Failed to free memory pointer at ",
|
||||
ptr);
|
||||
|
||||
// If that user code raised an error, just print it without raising it
|
||||
if (PyErr_Occurred()) {
|
||||
PyErr_Print();
|
||||
}
|
||||
|
||||
// Restore the original error
|
||||
PyErr_Restore(type, value, traceback);
|
||||
}
|
||||
|
||||
} // namespace openreg
|
||||
|
|
|
|||
|
|
@ -25,23 +25,11 @@ struct HostAllocator final : at::Allocator {
|
|||
get_method("hostMalloc")(nbytes).cast<openreg_ptr_t>());
|
||||
TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host.");
|
||||
}
|
||||
return {data, data, &ReportAndDelete, at::Device(at::kCPU)};
|
||||
}
|
||||
|
||||
static void ReportAndDelete(void* ptr) {
|
||||
if (!ptr) {
|
||||
return;
|
||||
}
|
||||
py::gil_scoped_acquire acquire;
|
||||
TORCH_CHECK(
|
||||
get_method("hostFree")(reinterpret_cast<openreg_ptr_t>(ptr))
|
||||
.cast<bool>(),
|
||||
"Failed to free memory pointer at ",
|
||||
ptr);
|
||||
return {data, data, &ReportAndDelete<kHostFreeMethod>, at::Device(at::kCPU)};
|
||||
}
|
||||
|
||||
at::DeleterFnPtr raw_deleter() const override {
|
||||
return &ReportAndDelete;
|
||||
return &ReportAndDelete<kHostFreeMethod>;
|
||||
}
|
||||
|
||||
void copy_data(void* dest, const void* src, std::size_t count) const final {
|
||||
|
|
|
|||
|
|
@ -26,36 +26,11 @@ struct OpenRegAllocator final : at::Allocator {
|
|||
TORCH_CHECK(
|
||||
data, "Failed to allocator ", nbytes, " bytes on openreg device.");
|
||||
}
|
||||
return {data, data, &ReportAndDelete, curr_device};
|
||||
}
|
||||
|
||||
static void ReportAndDelete(void* ptr) {
|
||||
if (!ptr || !Py_IsInitialized()) {
|
||||
return;
|
||||
}
|
||||
|
||||
py::gil_scoped_acquire acquire;
|
||||
|
||||
PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
|
||||
// Always stash, this will be a no-op if there is no error
|
||||
PyErr_Fetch(&type, &value, &traceback);
|
||||
|
||||
TORCH_CHECK(
|
||||
get_method("free")(reinterpret_cast<openreg_ptr_t>(ptr)).cast<bool>(),
|
||||
"Failed to free memory pointer at ",
|
||||
ptr);
|
||||
|
||||
// If that user code raised an error, just print it without raising it
|
||||
if (PyErr_Occurred()) {
|
||||
PyErr_Print();
|
||||
}
|
||||
|
||||
// Restore the original error
|
||||
PyErr_Restore(type, value, traceback);
|
||||
return {data, data, &ReportAndDelete<kFreeMethod>, curr_device};
|
||||
}
|
||||
|
||||
at::DeleterFnPtr raw_deleter() const override {
|
||||
return &ReportAndDelete;
|
||||
return &ReportAndDelete<kFreeMethod>;
|
||||
}
|
||||
|
||||
void copy_data(void* dest, const void* src, std::size_t count) const final {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user