#pragma once

#include <atomic>
#include <cstdint>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <torch/csrc/distributed/autograd/context/context.h>

namespace torch::distributed::autograd {

// Singleton class per worker which is responsible for storing the distributed
// autograd context for each autograd pass and also cleans up data for an
// autograd pass once it is done.
//
// Each autograd pass is assigned a unique autograd_context_id and all data for
// that pass (DistAutogradContext) is stored in this container indexed by the
// autograd_context_id. The autograd_context_id itself is a 64 bit globally
// unique id. The first 16 bits are the worker_id and the next 48 bits are an
// auto-incrementing id for each worker.
//
// This container is also responsible for maintaining a globally unique message
// id, which is used to associate send/recv autograd function pairs. The format
// is similar to the autograd_context_id: a 64 bit integer whose first 16 bits
// are the worker id and whose next 48 bits are auto-incrementing.
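//
// For example (illustrative only, not part of the original comment): with
// worker_id 3, ids take the form
//
//   int64_t id = (static_cast<int64_t>(worker_id) << 48) | counter;
//   // counter 0 -> 0x0003000000000000, counter 1 -> 0x0003000000000001
//
// where counter is the worker's 48 bit auto-incrementing value.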
class TORCH_API DistAutogradContainer {
 public:
  explicit DistAutogradContainer(uint32_t num_shards);

  // One time initialization of the container.
  static DistAutogradContainer& init(int64_t worker_id);

  // Retrieves the singleton instance of the container and ensures it has
  // been initialized.
  static DistAutogradContainer& getInstance();

  // Create a new context for a distributed autograd pass.
  const ContextPtr newContext();

  // Clean up resources for a given context_id once the autograd pass is done.
  // Sends RPC to other workers this worker knows about, telling them to clean
  // up their context as well. Throws an exception if the context_id does not
  // exist.
  void releaseContext(int64_t context_id);
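
  // A minimal lifecycle sketch (illustrative only; assumes init() was already
  // called with this worker's id and that DistAutogradContext exposes a
  // contextId() accessor):
  //
  //   auto& container = DistAutogradContainer::getInstance();
  //   ContextPtr ctx = container.newContext();
  //   // ... run the distributed forward and backward pass ...
  //   container.releaseContext(ctx->contextId());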

  // Releases an autograd context if it is present on this node. Also sends RPC
  // to other workers this worker knows about, telling them to clean up their
  // context. Does nothing if it is not present.
  void releaseContextIfPresent(int64_t context_id);

  // Checks if the passed in context_id is valid, throwing an exception if it
  // is not.
  void isValidContext(int64_t context_id);

  // Retrieve the autograd context for a given context_id.
  ContextPtr retrieveContext(int64_t context_id);

  // Retrieves the currently active autograd context for the current thread.
  ContextPtr currentContext();

  // Checks whether or not the current thread has a valid autograd context.
  bool hasValidContext() const;

  // Generate a new autograd_message_id for send/recv autograd functions.
  int64_t newAutogradMessageId();

  // Creates a new autograd context with the provided context_id. If a context
  // already exists with the provided context_id, we just return it.
  // This does not set the current context for the current thread.
  ContextPtr getOrCreateContext(int64_t context_id);

  // Retrieves the maximum possible autograd_context_id/autograd_message_id
  // that can be generated by this worker.
  int64_t getMaxId();
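
  // For instance (illustrative only): with worker id w, the maximum id is
  //
  //   (static_cast<int64_t>(w) << 48) | ((int64_t(1) << 48) - 1)
  //
  // i.e. the worker's 16 bit prefix followed by all ones in the low 48 bits.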

  // Retrieves the worker ID for this node.
  rpc::worker_id_t getWorkerId() const;

  // Sets the current context id if there is no valid context yet.
  static void setCurrentContextId(int64_t contextId);

  // Forcibly sets the thread local current context id. Should only be used in
  // cases where you know what you're doing and need to override the thread
  // local. Otherwise, use setCurrentContextId instead.
  static void forceCurrentContextId(int64_t contextId);

  // Clears the current context id.
  void clearCurrentContext();

  // Returns the number of autograd contexts in the container.
  size_t numAutogradContexts() const;

  // Returns the current thread local context id for this thread.
  static int64_t currentContextId();

  DistAutogradContainer() = delete;
  ~DistAutogradContainer() = default;
  DistAutogradContainer(const DistAutogradContainer&) = delete;
  DistAutogradContainer& operator=(const DistAutogradContainer&) = delete;
  DistAutogradContainer(DistAutogradContainer&&) = delete;
  DistAutogradContainer& operator=(DistAutogradContainer&&) = delete;

 private:
  // Number of shards for the map storing autograd contexts. We'd like this
  // to be a power of 2, and we don't expect a value much higher than the
  // number of cores to provide much benefit.
  static constexpr uint32_t kNumDefaultShards = 128;

  // Use cache line size for alignment.
  static constexpr int kCacheLineSize = 64;

  // Structure holding one shard of the sharded autograd context map with its
  // associated lock. Align to cache line size to avoid contention between
  // adjacent entries.
  struct alignas(kCacheLineSize) ContextsShard {
    // Lock for this shard.
    mutable std::mutex lock;

    // Map storing autograd contexts for this shard.
    std::unordered_map<int64_t, ContextPtr> contexts;
  };
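
  // Sanity check: the alignas above should give each shard its own cache
  // line, so adjacent shards never share one.
  static_assert(
      alignof(ContextsShard) == kCacheLineSize,
      "ContextsShard must be cache line aligned");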

  static DistAutogradContainer& getInstanceInternal();

  // Retrieve the shard for the given context_id.
  ContextsShard& getShard(int64_t context_id);
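
  // One plausible mapping (a sketch, not necessarily the actual
  // implementation): since the shard count is kept a power of 2, the shard
  // can be picked with a cheap mask instead of a modulo:
  //
  //   return autograd_contexts_[context_id & (num_shards_ - 1)];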

  // Sends an RPC to the workers that have a context corresponding to the
  // passed in context_id. This function should be called with the lock held.
  void sendReleaseContextRpc(
      const std::unordered_set<rpc::worker_id_t>& workerIds,
      int64_t context_id);

  // Erase context_id from the autograd context map, and reset the thread local
  // current context id if it corresponds to the passed in context id. This
  // function should be called with the lock held.
  void eraseContextIdAndReset(ContextsShard& shard, int64_t context_id);

  // Compute the number of shards for the autograd_contexts_ map.
  static uint32_t computeNumShards();
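
  // A sketch of one reasonable policy (the actual logic lives in the .cpp
  // file): round the hardware concurrency up to the next power of 2, falling
  // back to kNumDefaultShards when it cannot be determined:
  //
  //   uint32_t shards = 1;
  //   auto hw = std::thread::hardware_concurrency();
  //   if (hw == 0) return kNumDefaultShards;
  //   while (shards < hw) shards <<= 1;
  //   return shards;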

  // Auto incrementing context id used to identify unique autograd passes.
  // Initialized with the first 16 bits being the worker_id.
  std::atomic<int64_t> next_context_id_;

  // Unique id to identify a worker in the distributed setting.
  int16_t worker_id_;

  // Whether or not the container has been initialized appropriately.
  bool initialized_;

  // Sharded autograd context map.
  std::vector<ContextsShard> autograd_contexts_;

  // Number of shards for the sharded autograd_contexts_ map.
  uint32_t num_shards_;

  // Autograd message id to identify unique send/recv autograd function pairs.
  std::atomic<int64_t> next_autograd_message_id_;

  // Maximum allowed value for autograd_context_id or autograd_message_id.
  int64_t max_id_;
};

} // namespace torch::distributed::autograd