mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Generalize host allocator to be device-agnostic (#123079)
# Motivation According to [[RFC] Intel GPU Runtime Upstreaming for Allocator](https://github.com/pytorch/pytorch/issues/116322), we would like to generalize device and host allocator to be device-agnostic. We prioritize the host allocator as it is simpler and more native than the device allocator. In this PR, we intend to refactor the host allocator to make it be shared across different backends. In 2nd PR, we will support host allocator on XPU backend. # Design The previous design: - `CUDAHostAllocatorWrapper` inherits from `c10::Allocator`, and `CUDAHostAllocator` is an implementation of `CUDAHostAllocatorWrapper`. The design in this PR: - `CachingHostAllocatorImpl` is an interface that implements the caching host allocator logic that can be sharable across each backend. - `CachingHostAllocatorInterface` inherits from `c10::Allocator` as an interface and accepts `CachingHostAllocatorImpl` as its implementation. - `CUDACachingHostAllocator` is a CUDA host allocator whose implementation is `CUDACachingHostAllocatorImpl` which is specialized from `CachingHostAllocatorImpl`. This design can - share most code of caching mechanism across different backends, and - keep the flexibility to expand its exclusive feature on each backend. # Additional Context In addition, we will continue to generalize the device allocator in the next stage. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123079 Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/albanD, https://github.com/gujinghui
This commit is contained in:
parent
82e0153487
commit
b7f898c4a6
380
aten/src/ATen/core/CachingHostAllocator.h
Normal file
380
aten/src/ATen/core/CachingHostAllocator.h
Normal file
|
|
@ -0,0 +1,380 @@
|
|||
#include <c10/core/Allocator.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
#include <c10/util/llvmMathExtras.h>
|
||||
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
|
||||
namespace at {
|
||||
|
||||
/**
|
||||
* HostBlock is typically a fundamental memory block used in pinned memory. It
|
||||
* is likely related to Event and Stream of device runtime. It is probably a
|
||||
* base struct or interface that can be inherited and extended by each backend.
|
||||
*/
|
||||
template <typename S>
|
||||
struct HostBlock {
|
||||
// constructor for search key
|
||||
HostBlock(size_t size) : size_(size) {}
|
||||
|
||||
HostBlock(size_t size, void* ptr) : size_(size), ptr_(ptr) {}
|
||||
|
||||
std::mutex mutex_;
|
||||
size_t size_{0}; // block size in bytes
|
||||
void* ptr_{nullptr}; // memory address
|
||||
bool allocated_{false}; // in-use flag
|
||||
size_t event_count_{0}; // number of related events
|
||||
ska::flat_hash_set<S> streams_; // streams on which the block was used
|
||||
};
|
||||
|
||||
/**
|
||||
* ComparatorSize is used for lookup support in the set of host memory blocks
|
||||
* using the block size.
|
||||
*/
|
||||
template <typename B>
|
||||
struct ComparatorSize {
|
||||
bool operator()(const B* a, const B* b) const {
|
||||
if (a->size_ != b->size_) {
|
||||
return a->size_ < b->size_;
|
||||
}
|
||||
return (uintptr_t)a->ptr_ < (uintptr_t)b->ptr_;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Note [HostAllocator design]
|
||||
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
* We have three key data structures - the free list which stores blocks that
|
||||
* are not currently used, the block list which stores all blocks that have been
|
||||
* allocated, and the event queue which stores runtime events and their
|
||||
* corresponding blocks.
|
||||
*
|
||||
* Each of these are protected by a separate mutex. The key design principles
|
||||
* are to 1) only hold each mutex for the minimal amount of time possible, 2)
|
||||
* never do any possible expensive operations (such as CUDA runtime API calls)
|
||||
* while holding the lock.
|
||||
*
|
||||
* There are four public methods: allocate, free, record_event and empty_cache.
|
||||
* 1) In the allocate path, we first check to see if we can service our
|
||||
* request from this free list, and otherwise we create a new block with
|
||||
* allocate_host_memory.
|
||||
* 2) In the free path, we insert events (if required) into the event queue,
|
||||
* and if possible insert our block back into the free list. In allocate, we
|
||||
* first eagerly query events until we find one that is not ready, and insert
|
||||
* the corresponding block onto the free list if all the events recorded for a
|
||||
* block are ready.
|
||||
* 3) In the record_event path, we simply insert the given stream into the set
|
||||
* of streams tracked by the specified block. This set of streams is then
|
||||
* consumed in the free path.
|
||||
* 4) In the empty_cache path, we flush any available blocks into the free
|
||||
* list. Remove all element of free list, then remove them from block list and
|
||||
* release the associated pinned memory allocation via free_block.
|
||||
*
|
||||
* We generalize the caching host allocator into two parts: interface and
|
||||
* implementation. For any new backend looking to integrate with host allocator
|
||||
* and reuse caching mechanism, these two parts are necessary to be specialized.
|
||||
*
|
||||
* For the implementation, we provide a CachingHostAllocatorImpl struct
|
||||
* to abstract the caching mechanism. Any backend needs to provide a customized
|
||||
* implementation by specializing its own public functions and the related
|
||||
* runtime functions. Its template parameter S represents runtime Stream, E
|
||||
* denotes runtime Event, B indicates the fundamental memory block, and C
|
||||
* signifies the sorting compartor algorithm for the memory blocks.
|
||||
*
|
||||
* For the interface, we provide a CachingHostAllocatorInterface struct as an
|
||||
* interface. Any backend needs to derive its own host allocator from this
|
||||
* interface. Its template parameter T refers to an implementation that
|
||||
* inherited from CachingHostAllocatorImpl.
|
||||
*
|
||||
* So this design can share the caching mechanism across each backend, and
|
||||
* provide flexibility to each backend. A backend can choose to follow this
|
||||
* implementation or reuse them by extending and overriding them as necessary.
|
||||
* Taking CUDA as an example, it specializes runtime related functions to reuse
|
||||
* the caching mechanism. Additionally, it extends the allocator's functionality
|
||||
* by adding the allocWithCudaHostRegister function to support page-locking the
|
||||
* memory range used by CUDA. Of course, you can also refer to
|
||||
* XPUCachingHostAllocator, which is a host caching allocator supported on XPU
|
||||
* backend, to implement a basic host caching allocator.
|
||||
*
|
||||
* Some of the invariants here are less strict than they could be - for example,
|
||||
* we do not enforce that free(Block* block) => block->event_count == 0. This is
|
||||
* for compatibility reasons, and we can explore enforcing these in subsequent
|
||||
* versions.
|
||||
*
|
||||
* Note that this caching host allocator does not split larger allocations into
|
||||
* smaller blocks, unlike the caching device allocator.
|
||||
*/
|
||||
|
||||
template <
|
||||
typename S,
|
||||
typename E,
|
||||
typename B = HostBlock<S>,
|
||||
typename C = ComparatorSize<B>>
|
||||
struct CachingHostAllocatorImpl {
|
||||
virtual ~CachingHostAllocatorImpl() = default;
|
||||
|
||||
public:
|
||||
// return data_ptr and block pair.
|
||||
virtual std::pair<void*, void*> allocate(size_t size) {
|
||||
if (size == 0) {
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
process_events();
|
||||
|
||||
// First, try to allocate from the free list
|
||||
auto* block = get_free_block(size);
|
||||
if (block) {
|
||||
return {block->ptr_, reinterpret_cast<void*>(block)};
|
||||
}
|
||||
|
||||
// Round up the allocation to the nearest power of two to improve reuse.
|
||||
size_t roundSize = c10::llvm::PowerOf2Ceil(size);
|
||||
void* ptr = nullptr;
|
||||
allocate_host_memory(roundSize, &ptr);
|
||||
|
||||
// Then, create a new block.
|
||||
block = new B(roundSize, ptr);
|
||||
block->allocated_ = true;
|
||||
|
||||
add_allocated_block(block);
|
||||
return {block->ptr_, reinterpret_cast<void*>(block)};
|
||||
}
|
||||
|
||||
virtual void free(void* ctx) {
|
||||
if (!ctx) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Note: we can assume that free is correctly paired with alloc, and thus we
|
||||
// do not need to look up the ctx in blocks_.
|
||||
auto* block = reinterpret_cast<B*>(ctx);
|
||||
|
||||
c10::optional<std::vector<E>> events;
|
||||
{
|
||||
std::lock_guard<std::mutex> g(block->mutex_);
|
||||
block->allocated_ = false;
|
||||
if (block->streams_.empty()) {
|
||||
TORCH_INTERNAL_ASSERT(block->event_count_ == 0);
|
||||
} else {
|
||||
events = std::vector<E>();
|
||||
events->reserve(block->streams_.size());
|
||||
for (auto stream : block->streams_) {
|
||||
record_stream(events, stream);
|
||||
}
|
||||
block->event_count_ += events->size();
|
||||
block->streams_.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (!events) {
|
||||
std::lock_guard<std::mutex> g(free_list_mutex_);
|
||||
free_list_.insert(block);
|
||||
} else {
|
||||
// restore these events that record by used streams.
|
||||
std::lock_guard<std::mutex> g(events_mutex_);
|
||||
for (auto&& event : *events) {
|
||||
events_.emplace_front(std::move(event), block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual bool record_event(void* ptr, void* ctx, S stream) {
|
||||
auto* block = reinterpret_cast<B*>(ctx);
|
||||
|
||||
// Note: we need to check if the passed-in `ctx` is valid. This is because
|
||||
// `record_event` (via `CachingHostAllocator_recordEvent`) can be invoked on
|
||||
// an arbitrary tensor, and is not guaranteed to correspond to a pinned
|
||||
// memory allocation. Therefore, we need to check that `ctx` is valid before
|
||||
// proceeding.
|
||||
{
|
||||
std::lock_guard<std::mutex> g(blocks_mutex_);
|
||||
if (blocks_.find(block) != blocks_.end()) {
|
||||
// Now we know this object is safe to access.
|
||||
std::lock_guard<std::mutex> gb(block->mutex_);
|
||||
TORCH_INTERNAL_ASSERT(block->allocated_);
|
||||
block->streams_.insert(stream);
|
||||
return true;
|
||||
}
|
||||
auto it = ptr_to_block_.find(ptr);
|
||||
if (it != ptr_to_block_.end()) {
|
||||
block = it->second;
|
||||
std::lock_guard<std::mutex> g(block->mutex_);
|
||||
TORCH_INTERNAL_ASSERT(block->allocated_);
|
||||
block->streams_.insert(stream);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual void empty_cache() {
|
||||
// Flush any available blocks into the free_list.
|
||||
process_events();
|
||||
|
||||
// Remove all elements from the free list, remove them from the blocks
|
||||
// list, and free the associated pinned memory allocation. This requires
|
||||
// concurrently holding both the free list mutex and the blocks mutex, and
|
||||
// is the only function that concurrently holds multiple mutexes.
|
||||
std::lock(free_list_mutex_, blocks_mutex_);
|
||||
std::lock_guard<std::mutex> gf(free_list_mutex_, std::adopt_lock);
|
||||
std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
|
||||
|
||||
std::vector<B*> blocks_to_remove(free_list_.begin(), free_list_.end());
|
||||
free_list_.clear();
|
||||
for (auto* block : blocks_to_remove) {
|
||||
blocks_.erase(block);
|
||||
ptr_to_block_.erase(block->ptr_);
|
||||
free_block(block);
|
||||
delete block;
|
||||
}
|
||||
}
|
||||
|
||||
virtual void copy_data(void* dest, const void* src, std::size_t count) const {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data");
|
||||
}
|
||||
|
||||
private:
|
||||
virtual void add_allocated_block(B* block) {
|
||||
std::lock_guard<std::mutex> g(blocks_mutex_);
|
||||
blocks_.insert(block);
|
||||
ptr_to_block_.insert({block->ptr_, block});
|
||||
}
|
||||
|
||||
virtual B* get_free_block(size_t size) {
|
||||
std::lock_guard<std::mutex> g(free_list_mutex_);
|
||||
B key(size);
|
||||
auto it = free_list_.lower_bound(&key);
|
||||
if (it != free_list_.end()) {
|
||||
B* block = *it;
|
||||
block->allocated_ = true;
|
||||
free_list_.erase(it);
|
||||
return block;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual void process_events() {
|
||||
|
||||
while (true) {
|
||||
// Avoid calling cudaEventDestroy while holding a mutex, so move
|
||||
// intermediate events out of the lock into this object.
|
||||
// process the last event
|
||||
c10::optional<std::pair<E, B*>> processed;
|
||||
{
|
||||
std::lock_guard<std::mutex> g(events_mutex_);
|
||||
if (!events_.empty()) {
|
||||
processed = std::move(events_.back());
|
||||
events_.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
if (!processed) {
|
||||
return;
|
||||
}
|
||||
|
||||
// otherwise, query the event
|
||||
{
|
||||
// now, see if we can handle this element
|
||||
auto& event = processed->first;
|
||||
if (!query_event(event)) {
|
||||
// push the event onto the back if it's not ready.
|
||||
{
|
||||
std::lock_guard<std::mutex> g(events_mutex_);
|
||||
events_.push_back(std::move(*processed));
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Process the events.
|
||||
TORCH_INTERNAL_ASSERT(processed);
|
||||
auto* block = processed->second;
|
||||
bool available = false;
|
||||
{
|
||||
std::lock_guard<std::mutex> g(block->mutex_);
|
||||
TORCH_INTERNAL_ASSERT(!block->allocated_)
|
||||
block->event_count_--;
|
||||
if (block->event_count_ == 0) {
|
||||
available = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (available) {
|
||||
std::lock_guard<std::mutex> g(free_list_mutex_);
|
||||
free_list_.insert(block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* These following functions are runtime-related. */
|
||||
|
||||
// Allocate page-locked memory on the host.
|
||||
virtual void allocate_host_memory(size_t size, void** ptr) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "Not implemented for allocate_host_memory");
|
||||
}
|
||||
|
||||
// Free block and release the pointer contained in block.
|
||||
virtual void free_block(B* block) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block");
|
||||
}
|
||||
|
||||
// Record an event on stream and store event into events.
|
||||
virtual void record_stream(c10::optional<std::vector<E>>& events, S stream) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream");
|
||||
}
|
||||
|
||||
// Query event if it is completed.
|
||||
virtual bool query_event(E& event) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
|
||||
}
|
||||
|
||||
alignas(64) std::mutex blocks_mutex_;
|
||||
ska::flat_hash_set<B*> blocks_; // block list
|
||||
ska::flat_hash_map<void*, B*> ptr_to_block_;
|
||||
|
||||
// Note: sharding this mutex seems to be profitable in heavily multi-threaded
|
||||
// scenarios.
|
||||
alignas(64) std::mutex free_list_mutex_;
|
||||
// Note: an alternative datastructure can yield significant wins here in
|
||||
// microbenchmarks.
|
||||
std::set<B*, C> free_list_; // free list
|
||||
|
||||
alignas(64) std::mutex events_mutex_;
|
||||
std::deque<std::pair<E, B*>> events_; // event queue paired with block
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CachingHostAllocatorInterface : public at::Allocator {
|
||||
CachingHostAllocatorInterface() :impl_(std::make_unique<T>()) {}
|
||||
|
||||
at::DataPtr allocate(size_t size) override {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for allocate");
|
||||
}
|
||||
|
||||
void free(void* ctx) {
|
||||
impl_->free(ctx);
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
bool record_event(void* ptr, void* ctx, S stream) {
|
||||
return impl_->record_event(ptr, ctx, stream);
|
||||
}
|
||||
|
||||
void empty_cache() {
|
||||
impl_->empty_cache();
|
||||
}
|
||||
|
||||
void copy_data(void* dest, const void* src, std::size_t count)
|
||||
const override {
|
||||
impl_->copy_data(dest, src, count);
|
||||
}
|
||||
|
||||
std::unique_ptr<T> impl_;
|
||||
};
|
||||
|
||||
} // namespace at
|
||||
|
|
@ -8,34 +8,11 @@
|
|||
#include <c10/cuda/CUDAAllocatorConfig.h>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <stdint.h>
|
||||
#include <deque>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
namespace at::cuda {
|
||||
namespace {
|
||||
|
||||
struct BlockSize {
|
||||
size_t size_{0};
|
||||
void* ptr_{nullptr};
|
||||
};
|
||||
|
||||
struct Block {
|
||||
size_t size_{0};
|
||||
void* ptr_{nullptr};
|
||||
|
||||
std::mutex mutex_;
|
||||
bool allocated_{false};
|
||||
size_t event_count_{0};
|
||||
std::unordered_set<at::cuda::CUDAStream> streams_;
|
||||
};
|
||||
|
||||
// Note: cudaEventCreate when concurrently invoked from multiple threads can be
|
||||
// very expensive (at least on certain device/driver combinations). Thus, we a)
|
||||
// serialize event creation at a per-device level, and b) pool the events to
|
||||
|
|
@ -89,81 +66,12 @@ class EventPool {
|
|||
std::vector<PerDevicePool> pools_;
|
||||
};
|
||||
|
||||
// Used for heterogenous lookup support in the free list.
|
||||
struct BlockComparator {
|
||||
using is_transparent = void;
|
||||
bool operator()(const Block* a, const Block* b) const {
|
||||
if (a->size_ != b->size_) {
|
||||
return a->size_ < b->size_;
|
||||
}
|
||||
return (uintptr_t)a->ptr_ < (uintptr_t)b->ptr_;
|
||||
}
|
||||
using Block = HostBlock<CUDAStream>;
|
||||
|
||||
// Transparent overloads
|
||||
bool operator()(const Block* a, BlockSize b) const {
|
||||
if (a->size_ != b.size_) {
|
||||
return a->size_ < b.size_;
|
||||
}
|
||||
return (uintptr_t)a->ptr_ < (uintptr_t)b.ptr_;
|
||||
}
|
||||
bool operator()(BlockSize a, const Block* b) const {
|
||||
if (a.size_ != b->size_) {
|
||||
return a.size_ < b->size_;
|
||||
}
|
||||
return (uintptr_t)a.ptr_ < (uintptr_t)b->ptr_;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Note [CUDAHostAllocator design]
|
||||
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
* We have three key data structures - the free list which stores blocks that
|
||||
* are not currently used, the block list which stores all blocks that have been
|
||||
* allocated, and the event queue which stores CUDA events and their
|
||||
* corresponding blocks.
|
||||
*
|
||||
* Each of these are protected by a separate mutex. The key design principles
|
||||
* are to 1) only hold each mutex for the minimal amount of time possible, 2)
|
||||
* never do any possible expensive operations (such as CUDA runtime API calls)
|
||||
* while holding the lock.
|
||||
*
|
||||
* There are three public methods: allocate, free, and record_event. In the
|
||||
* allocate path, we first check to see if we can service our request from this
|
||||
* free list, and otherwise we create a new block with cudaHostAlloc. In the
|
||||
* free path, we insert events (if required) into the event queue, and if
|
||||
* possible insert our block back into the free list. In allocate, we first
|
||||
* eagerly query events until we find one that is not ready, and insert the
|
||||
* corresponding block onto the free list if all the events recorded for a
|
||||
* block are ready. In the record_event path, we simply insert the given
|
||||
* stream into the set of streams tracked by the specified block. This set of
|
||||
* streams is then consumed in the free path.
|
||||
*
|
||||
* Some of the invariants here are less strict than they could be - for example,
|
||||
* we do not enforce that free(Block* block) => block->event_count == 0. This is
|
||||
* for compatibility reasons, and we can explore enforcing these in subsequent
|
||||
* versions.
|
||||
*/
|
||||
class CUDAHostAllocator {
|
||||
public:
|
||||
std::pair<void*, void*> allocate(size_t size) {
|
||||
if (size == 0) {
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
process_events();
|
||||
|
||||
// First, try to allocate from the free list
|
||||
{
|
||||
std::lock_guard<std::mutex> g(free_list_mutex_);
|
||||
auto it = free_list_.lower_bound(BlockSize{size, nullptr});
|
||||
if (it != free_list_.end()) {
|
||||
auto block = *it;
|
||||
block->allocated_ = true;
|
||||
free_list_.erase(it);
|
||||
return {block->ptr_, reinterpret_cast<void*>(block)};
|
||||
}
|
||||
}
|
||||
// Then, create a new block.
|
||||
struct CUDACachingHostAllocatorImpl
|
||||
: public CachingHostAllocatorImpl<CUDAStream, EventPool::Event> {
|
||||
private:
|
||||
void allocate_host_memory(size_t size, void** ptr) override {
|
||||
// Pinned memory pointers allocated by any device can be directly used by
|
||||
// any other device, regardless of the current device at the time of
|
||||
// allocation, since we assume unified addressing. So we grab any existing
|
||||
|
|
@ -176,192 +84,49 @@ class CUDAHostAllocator {
|
|||
at::Device(at::DeviceType::CUDA, *primary_ctx_device_index));
|
||||
}
|
||||
|
||||
// Round up the allocation to the nearest power of two to improve reuse.
|
||||
size_t roundSize = c10::llvm::PowerOf2Ceil(size);
|
||||
void* ptr = nullptr;
|
||||
if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
|
||||
pinned_use_cuda_host_register()) {
|
||||
allocWithCudaHostRegister(&ptr, roundSize);
|
||||
allocWithCudaHostRegister(ptr, size);
|
||||
} else {
|
||||
// Use cudaHostAlloc for allocating pinned memory (global lock in driver)
|
||||
C10_CUDA_CHECK(cudaHostAlloc(&ptr, roundSize, cudaHostAllocDefault));
|
||||
C10_CUDA_CHECK(cudaHostAlloc(ptr, size, cudaHostAllocDefault));
|
||||
}
|
||||
|
||||
auto block = new Block();
|
||||
block->size_ = roundSize;
|
||||
block->ptr_ = ptr;
|
||||
block->allocated_ = true;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> g(blocks_mutex_);
|
||||
blocks_.insert(block);
|
||||
ptr_to_block_.insert({block->ptr_, block});
|
||||
}
|
||||
return {block->ptr_, reinterpret_cast<void*>(block)};
|
||||
}
|
||||
|
||||
void free(void* ctx) {
|
||||
if (!ctx) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Note: we can assume that free is correctly paired with alloc,
|
||||
// and thus we do not need to look up the ctx in blocks_.
|
||||
auto* block = reinterpret_cast<Block*>(ctx);
|
||||
|
||||
c10::optional<std::vector<EventPool::Event>> events;
|
||||
{
|
||||
std::lock_guard<std::mutex> g(block->mutex_);
|
||||
block->allocated_ = false;
|
||||
if (block->streams_.empty()) {
|
||||
TORCH_INTERNAL_ASSERT(block->event_count_ == 0);
|
||||
} else {
|
||||
events = std::vector<EventPool::Event>();
|
||||
events->reserve(block->streams_.size());
|
||||
for (auto stream : block->streams_) {
|
||||
auto event = event_pool_.get(stream.device_index());
|
||||
event->record(stream);
|
||||
events->push_back(std::move(event));
|
||||
}
|
||||
block->event_count_ += events->size();
|
||||
block->streams_.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (!events) {
|
||||
std::lock_guard<std::mutex> g(free_list_mutex_);
|
||||
free_list_.insert(block);
|
||||
void free_block(Block* block) override {
|
||||
if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
|
||||
pinned_use_cuda_host_register()) {
|
||||
void* ptr = block->ptr_;
|
||||
AT_CUDA_CHECK(cudaHostUnregister(ptr));
|
||||
free(ptr);
|
||||
} else {
|
||||
std::lock_guard<std::mutex> g(cuda_events_mutex_);
|
||||
for (auto&& event : *events) {
|
||||
cuda_events_.emplace_front(std::move(event), block);
|
||||
}
|
||||
AT_CUDA_CHECK(cudaFreeHost(block->ptr_));
|
||||
}
|
||||
}
|
||||
|
||||
bool record_event(void* ptr, void* ctx, at::cuda::CUDAStream stream) {
|
||||
auto* block = reinterpret_cast<Block*>(ctx);
|
||||
|
||||
// Note: we need to check if the passed-in `ctx` is valid. This is because
|
||||
// `record_event` (via `CachingHostAllocator_recordEvent`) can be invoked on
|
||||
// an arbitrary tensor, and is not guaranteed to correspond to a pinned
|
||||
// memory allocation. Therefore, we need to check that `ctx` is valid before
|
||||
// proceeding.
|
||||
{
|
||||
std::lock_guard<std::mutex> g(blocks_mutex_);
|
||||
if (blocks_.find(block) != blocks_.end()) {
|
||||
// Now we know this object is safe to access.
|
||||
std::lock_guard<std::mutex> gb(block->mutex_);
|
||||
TORCH_INTERNAL_ASSERT(block->allocated_);
|
||||
block->streams_.insert(stream);
|
||||
return true;
|
||||
}
|
||||
auto it = ptr_to_block_.find(ptr);
|
||||
if (it != ptr_to_block_.end()) {
|
||||
block = it->second;
|
||||
std::lock_guard<std::mutex> g(block->mutex_);
|
||||
TORCH_INTERNAL_ASSERT(block->allocated_);
|
||||
block->streams_.insert(stream);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
void record_stream(
|
||||
c10::optional<std::vector<EventPool::Event>>& events,
|
||||
CUDAStream stream) override {
|
||||
auto event = create_event_internal(stream.device_index());
|
||||
event->record(stream);
|
||||
events->push_back(std::move(event));
|
||||
}
|
||||
|
||||
void empty_cache() {
|
||||
// Flush any available blocks into the free_list.
|
||||
process_events();
|
||||
|
||||
// Release cached events from the event pool.
|
||||
event_pool_.empty_cache();
|
||||
|
||||
// Remove all elements from the free list, remove them from the blocks
|
||||
// list, and free the associated pinned memory allocation. This requires
|
||||
// concurrently holding both the free list mutex and the blocks mutex, and
|
||||
// is the only function that concurrently holds multiple mutexes.
|
||||
std::lock(free_list_mutex_, blocks_mutex_);
|
||||
std::lock_guard<std::mutex> gf(free_list_mutex_, std::adopt_lock);
|
||||
std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);
|
||||
|
||||
std::vector<Block*> blocks_to_remove(free_list_.begin(), free_list_.end());
|
||||
free_list_.clear();
|
||||
for (auto* block : blocks_to_remove) {
|
||||
blocks_.erase(block);
|
||||
ptr_to_block_.erase(block->ptr_);
|
||||
if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
|
||||
pinned_use_cuda_host_register()) {
|
||||
void* ptr = block->ptr_;
|
||||
AT_CUDA_CHECK(cudaHostUnregister(ptr));
|
||||
free(ptr);
|
||||
} else {
|
||||
AT_CUDA_CHECK(cudaFreeHost(block->ptr_));
|
||||
}
|
||||
delete block;
|
||||
bool query_event(EventPool::Event& event) override {
|
||||
cudaError_t err = cudaEventQuery(*event);
|
||||
if (err == cudaErrorNotReady) {
|
||||
(void)cudaGetLastError(); // clear CUDA error
|
||||
return false;
|
||||
} else if (err != cudaSuccess) {
|
||||
C10_CUDA_CHECK(err);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void copy_data(void* dest, const void* src, std::size_t count) const {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for CUDAHostAllocator");
|
||||
}
|
||||
|
||||
private:
|
||||
void process_events() {
|
||||
while (true) {
|
||||
// Avoid calling cudaEventDestroy while holding a mutex, so move
|
||||
// intermediate events out of the lock into this object.
|
||||
c10::optional<std::pair<EventPool::Event, Block*>> processed;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> g(cuda_events_mutex_);
|
||||
if (!cuda_events_.empty()) {
|
||||
processed = std::move(cuda_events_.back());
|
||||
cuda_events_.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
if (!processed) {
|
||||
return;
|
||||
}
|
||||
|
||||
// otherwise, query the event
|
||||
{
|
||||
// now, see if we can handle this element
|
||||
auto& event = processed->first;
|
||||
cudaError_t err = cudaEventQuery(*event);
|
||||
if (err == cudaErrorNotReady) {
|
||||
(void)cudaGetLastError(); // clear CUDA error
|
||||
// push the event onto the back of the queue if it's not
|
||||
// ready. TODO: do we need some debouncing logic to avoid allocating
|
||||
// threads repeatedly spinning on an event?
|
||||
{
|
||||
std::lock_guard<std::mutex> g(cuda_events_mutex_);
|
||||
cuda_events_.push_back(std::move(*processed));
|
||||
}
|
||||
return;
|
||||
} else if (err != cudaSuccess) {
|
||||
C10_CUDA_CHECK(err);
|
||||
}
|
||||
}
|
||||
|
||||
// Process the events.
|
||||
TORCH_INTERNAL_ASSERT(processed);
|
||||
auto* block = processed->second;
|
||||
bool available = false;
|
||||
{
|
||||
std::lock_guard<std::mutex> g(block->mutex_);
|
||||
TORCH_INTERNAL_ASSERT(!block->allocated_)
|
||||
block->event_count_--;
|
||||
if (block->event_count_ == 0) {
|
||||
available = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (available) {
|
||||
std::lock_guard<std::mutex> g(free_list_mutex_);
|
||||
free_list_.insert(block);
|
||||
}
|
||||
}
|
||||
EventPool::Event create_event_internal(DeviceIndex idx) {
|
||||
// Leak the event pool to avoid shutdown issue.
|
||||
static auto* event_pool = new EventPool();
|
||||
return event_pool->get(idx);
|
||||
}
|
||||
|
||||
TaskThreadPool* getThreadPool() {
|
||||
|
|
@ -406,7 +171,7 @@ class CUDAHostAllocator {
|
|||
"");
|
||||
}
|
||||
|
||||
inline void allocWithCudaHostRegister(void** ptr, size_t roundSize) {
|
||||
void allocWithCudaHostRegister(void** ptr, size_t roundSize) {
|
||||
// Here we do regular allocation, pre-fault/map the pages, and then do
|
||||
// cudaHostRegister with GPU mapping flags to lock the pages, so we
|
||||
// can minimize the cost for the cuda global lock.
|
||||
|
|
@ -427,13 +192,19 @@ class CUDAHostAllocator {
|
|||
for (size_t i = 0; i < numMapThreads; i++) {
|
||||
promises.emplace_back();
|
||||
futures.push_back(promises[i].get_future());
|
||||
auto task = [this, i, ptr, roundSize, numMapThreads, pageSize, &promises]() mutable {
|
||||
auto task = [this,
|
||||
i,
|
||||
ptr,
|
||||
roundSize,
|
||||
numMapThreads,
|
||||
pageSize,
|
||||
&promises]() mutable {
|
||||
mapPagesForRegister(
|
||||
*ptr,
|
||||
roundSize,
|
||||
i, // thread task-id
|
||||
numMapThreads,
|
||||
pageSize);
|
||||
*ptr,
|
||||
roundSize,
|
||||
i, // thread task-id
|
||||
numMapThreads,
|
||||
pageSize);
|
||||
// set the promise when mapping pages are done
|
||||
promises[i].set_value();
|
||||
};
|
||||
|
|
@ -450,66 +221,48 @@ class CUDAHostAllocator {
|
|||
// Register the mapped pages using cudaHostRegister
|
||||
registerPages(*ptr, roundSize);
|
||||
}
|
||||
|
||||
EventPool event_pool_;
|
||||
|
||||
alignas(64) std::mutex blocks_mutex_;
|
||||
std::unordered_set<Block*> blocks_;
|
||||
std::unordered_map<void*, Block*> ptr_to_block_;
|
||||
// Note: sharding this mutex seems to be profitable in heavily multi-threaded
|
||||
// scenarios.
|
||||
alignas(64) std::mutex free_list_mutex_;
|
||||
// Note: an alternative datastructure can yield significant wins here in
|
||||
// microbenchmarks.
|
||||
std::set<Block*, BlockComparator> free_list_;
|
||||
|
||||
alignas(64) std::mutex cuda_events_mutex_;
|
||||
std::deque<std::pair<EventPool::Event, Block*>> cuda_events_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
void raw_local_deleter(void* ptr);
|
||||
|
||||
static CUDAHostAllocator& getCUDAHostAllocator() {
|
||||
// leak and don't worry about shutdown
|
||||
static auto* r = new CUDAHostAllocator();
|
||||
return *r;
|
||||
struct CUDACachingHostAllocator final
|
||||
: public CachingHostAllocatorInterface<CUDACachingHostAllocatorImpl> {
|
||||
at::DataPtr allocate(size_t size) override {
|
||||
auto ptr_and_ctx = impl_->allocate(size);
|
||||
return {
|
||||
ptr_and_ctx.first,
|
||||
ptr_and_ctx.second,
|
||||
&raw_local_deleter,
|
||||
at::DeviceType::CPU};
|
||||
}
|
||||
};
|
||||
|
||||
CUDACachingHostAllocator caching_host_allocator;
|
||||
|
||||
static inline CUDACachingHostAllocator& getCUDACachingHostAllocator() {
|
||||
return caching_host_allocator;
|
||||
}
|
||||
|
||||
static void CUDAHostAllocatorDeleter(void* ctx) {
|
||||
getCUDAHostAllocator().free(ctx);
|
||||
void raw_local_deleter(void* ptr) {
|
||||
getCUDACachingHostAllocator().free(ptr);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
bool CachingHostAllocator_recordEvent(
|
||||
void* ptr,
|
||||
void* ctx,
|
||||
at::cuda::CUDAStream stream) {
|
||||
return getCUDAHostAllocator().record_event(ptr, ctx, stream);
|
||||
return getCUDACachingHostAllocator().record_event(ptr, ctx, stream);
|
||||
}
|
||||
|
||||
// Releases cached pinned memory allocations via cudaHostFree
|
||||
void CachingHostAllocator_emptyCache() {
|
||||
getCUDAHostAllocator().empty_cache();
|
||||
getCUDACachingHostAllocator().empty_cache();
|
||||
}
|
||||
|
||||
struct CUDAHostAllocatorWrapper final : public at::Allocator {
|
||||
at::DataPtr allocate(size_t size) override {
|
||||
auto ptr_and_ctx = getCUDAHostAllocator().allocate(size);
|
||||
return {
|
||||
ptr_and_ctx.first,
|
||||
ptr_and_ctx.second,
|
||||
&CUDAHostAllocatorDeleter,
|
||||
at::DeviceType::CPU};
|
||||
}
|
||||
|
||||
void copy_data(void* dest, const void* src, std::size_t count) const final {
|
||||
getCUDAHostAllocator().copy_data(dest, src, count);
|
||||
}
|
||||
};
|
||||
|
||||
static CUDAHostAllocatorWrapper cuda_host_allocator;
|
||||
|
||||
at::Allocator* getCachingHostAllocator() {
|
||||
return &cuda_host_allocator;
|
||||
return &getCUDACachingHostAllocator();
|
||||
}
|
||||
|
||||
} // namespace at::cuda
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/core/CachingHostAllocator.h>
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
|
|
@ -17,15 +18,14 @@ namespace at::cuda {
|
|||
// call between host and device, and passed the corresponding context from the
|
||||
// allocation. This is currently invoked by at::native::copy_kernel_cuda.
|
||||
//
|
||||
// Note that this allocator does not split larger allocations into smaller
|
||||
// blocks, unlike the caching device allocator.
|
||||
//
|
||||
TORCH_CUDA_CPP_API c10::Allocator* getCachingHostAllocator();
|
||||
|
||||
// Records an event in the specified stream. The allocation corresponding to the
|
||||
// input `ptr`/`ctx` will not be re-used until the event has occurred.
|
||||
TORCH_CUDA_CPP_API bool
|
||||
CachingHostAllocator_recordEvent(void* ptr, void* ctx, c10::cuda::CUDAStream stream);
|
||||
TORCH_CUDA_CPP_API bool CachingHostAllocator_recordEvent(
|
||||
void* ptr,
|
||||
void* ctx,
|
||||
c10::cuda::CUDAStream stream);
|
||||
|
||||
// Releases cached pinned memory allocations via cudaHostFree
|
||||
TORCH_CUDA_CPP_API void CachingHostAllocator_emptyCache();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user