mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Openreg][PrivateUse1] Fix releasing tensor issue when using pin_memory (#151091)
As the title stated. Related PR: https://github.com/pytorch/pytorch/pull/147066 Co-authored-by: Zhenbin Lin <lin-zhenbin@qq.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/151091 Approved by: https://github.com/albanD ghstack dependencies: #151005, #151007
This commit is contained in:
parent
c7400d0026
commit
e229ce34c4
|
|
@ -9,4 +9,33 @@ using openreg_ptr_t = uint64_t;
|
||||||
void set_impl_factory(PyObject* factory);
|
void set_impl_factory(PyObject* factory);
|
||||||
py::function get_method(const char* name);
|
py::function get_method(const char* name);
|
||||||
|
|
||||||
|
static constexpr char kFreeMethod[] = "free";
|
||||||
|
static constexpr char kHostFreeMethod[] = "hostFree";
|
||||||
|
|
||||||
|
template <const char* name>
|
||||||
|
static void ReportAndDelete(void* ptr) {
|
||||||
|
if (!ptr || !Py_IsInitialized()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
py::gil_scoped_acquire acquire;
|
||||||
|
|
||||||
|
PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
|
||||||
|
// Always stash, this will be a no-op if there is no error
|
||||||
|
PyErr_Fetch(&type, &value, &traceback);
|
||||||
|
|
||||||
|
TORCH_CHECK(
|
||||||
|
get_method(name)(reinterpret_cast<openreg_ptr_t>(ptr)).cast<bool>(),
|
||||||
|
"Failed to free memory pointer at ",
|
||||||
|
ptr);
|
||||||
|
|
||||||
|
// If that user code raised an error, just print it without raising it
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
PyErr_Print();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore the original error
|
||||||
|
PyErr_Restore(type, value, traceback);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace openreg
|
} // namespace openreg
|
||||||
|
|
|
||||||
|
|
@ -25,23 +25,11 @@ struct HostAllocator final : at::Allocator {
|
||||||
get_method("hostMalloc")(nbytes).cast<openreg_ptr_t>());
|
get_method("hostMalloc")(nbytes).cast<openreg_ptr_t>());
|
||||||
TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host.");
|
TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host.");
|
||||||
}
|
}
|
||||||
return {data, data, &ReportAndDelete, at::Device(at::kCPU)};
|
return {data, data, &ReportAndDelete<kHostFreeMethod>, at::Device(at::kCPU)};
|
||||||
}
|
|
||||||
|
|
||||||
static void ReportAndDelete(void* ptr) {
|
|
||||||
if (!ptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
TORCH_CHECK(
|
|
||||||
get_method("hostFree")(reinterpret_cast<openreg_ptr_t>(ptr))
|
|
||||||
.cast<bool>(),
|
|
||||||
"Failed to free memory pointer at ",
|
|
||||||
ptr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
at::DeleterFnPtr raw_deleter() const override {
|
at::DeleterFnPtr raw_deleter() const override {
|
||||||
return &ReportAndDelete;
|
return &ReportAndDelete<kHostFreeMethod>;
|
||||||
}
|
}
|
||||||
|
|
||||||
void copy_data(void* dest, const void* src, std::size_t count) const final {
|
void copy_data(void* dest, const void* src, std::size_t count) const final {
|
||||||
|
|
|
||||||
|
|
@ -26,36 +26,11 @@ struct OpenRegAllocator final : at::Allocator {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
data, "Failed to allocator ", nbytes, " bytes on openreg device.");
|
data, "Failed to allocator ", nbytes, " bytes on openreg device.");
|
||||||
}
|
}
|
||||||
return {data, data, &ReportAndDelete, curr_device};
|
return {data, data, &ReportAndDelete<kFreeMethod>, curr_device};
|
||||||
}
|
|
||||||
|
|
||||||
static void ReportAndDelete(void* ptr) {
|
|
||||||
if (!ptr || !Py_IsInitialized()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
py::gil_scoped_acquire acquire;
|
|
||||||
|
|
||||||
PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
|
|
||||||
// Always stash, this will be a no-op if there is no error
|
|
||||||
PyErr_Fetch(&type, &value, &traceback);
|
|
||||||
|
|
||||||
TORCH_CHECK(
|
|
||||||
get_method("free")(reinterpret_cast<openreg_ptr_t>(ptr)).cast<bool>(),
|
|
||||||
"Failed to free memory pointer at ",
|
|
||||||
ptr);
|
|
||||||
|
|
||||||
// If that user code raised an error, just print it without raising it
|
|
||||||
if (PyErr_Occurred()) {
|
|
||||||
PyErr_Print();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Restore the original error
|
|
||||||
PyErr_Restore(type, value, traceback);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
at::DeleterFnPtr raw_deleter() const override {
|
at::DeleterFnPtr raw_deleter() const override {
|
||||||
return &ReportAndDelete;
|
return &ReportAndDelete<kFreeMethod>;
|
||||||
}
|
}
|
||||||
|
|
||||||
void copy_data(void* dest, const void* src, std::size_t count) const final {
|
void copy_data(void* dest, const void* src, std::size_t count) const final {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user