#include <ATen/ATen.h>
#include <ATen/CachedTensorUtils.h>
#include <c10/util/flat_hash_map.h>
namespace at::caching {
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
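// Global toggle, flipped via set_cached_tensors_enabled(); when disabled,
// is_cached_tensor() short-circuits to false and add/remove hard-assert.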
static bool cached_tensorimpls_enabled = false;
// Like `cached_casts` in autocast_mode, we hash on the TensorImpl*
// and keep the pointer alive with a weakref value.
static ska::flat_hash_map<TensorImpl*, weakref_type> cached_tensorimpls;
static std::mutex cached_tensorimpl_mutex;
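// Returns true if `t`'s TensorImpl is currently registered as a cached tensor.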
bool is_cached_tensor(const at::Tensor& t) {
  if (!cached_tensorimpls_enabled) {
    return false;
  }
  const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
  return cached_tensorimpls.count(t.unsafeGetTensorImpl());
}
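// Registers `t` as a cached tensor. Only a weak reference is stored here, so
// registration by itself does not keep the tensor alive.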
void add_cached_tensor(const at::Tensor& t) {
  TORCH_INTERNAL_ASSERT(cached_tensorimpls_enabled);
  const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
  cached_tensorimpls.emplace(t.unsafeGetTensorImpl(), weakref_type(t.getIntrusivePtr()));
}
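// Removes `t` from the set of cached tensors.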
void remove_cached_tensor(const at::Tensor& t) {
  TORCH_INTERNAL_ASSERT(cached_tensorimpls_enabled);
  const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
  cached_tensorimpls.erase(t.unsafeGetTensorImpl());
}
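// Globally enables or disables cached-tensor tracking.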
void set_cached_tensors_enabled(bool enabled) {
  cached_tensorimpls_enabled = enabled;
}
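// Use count of `t`, discounting the one reference its owning cache is assumed
// to hold when the tensor is registered as cached.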
size_t adjusted_use_count(const at::Tensor& t) {
  return t.use_count() - (is_cached_tensor(t) ? 1 : 0);
}
} // namespace at::caching
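// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of this file). It assumes the cache
// that registers a tensor also holds its own strong reference elsewhere; the
// weakref stored above does not contribute to use_count().
//
//   at::caching::set_cached_tensors_enabled(true);
//
//   std::vector<at::Tensor> long_lived_cache;    // the cache's strong reference
//   at::Tensor t = at::empty({8});
//   long_lived_cache.push_back(t);               // t.use_count() == 2
//   at::caching::add_cached_tensor(t);           // mark as cache-owned
//
//   // The cache's reference is discounted: reports 1 outstanding reference.
//   size_t external = at::caching::adjusted_use_count(t);
//
//   at::caching::remove_cached_tensor(t);
//   at::caching::set_cached_tensors_enabled(false);
// ---------------------------------------------------------------------------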