mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This implements a new `wait_stream` API in Work for CPU-based backends such as Gloo that matches how `wait` works for ProcessGroupNCCL. The idea is to support Gloo communication overlap in FSDPv2/HSDP with minimal changes to FSDP. There was a previous attempt to make FSDPv2 use Work.wait (https://github.com/pytorch/pytorch/pull/148780), but given the extensive stream semantics FSDPv2 relies on, it does not play nicely. This change instead uses a "Baton" CUDA kernel that spin-waits on a pinned CPU tensor until it is set. Test plan:
```
pytest test/distributed/test_c10d_gloo.py -v -k wait_stream
pytest test/distributed/test_c10d_nccl.py -v -k wait_stream
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
39 lines
863 B
C++
39 lines
863 B
C++
#pragma once
|
|
|
|
#include <chrono>
#include <cstdint>
#include <memory>

#include <c10/util/Registry.h>
|
|
|
|
namespace c10d::cuda {
|
|
|
|
enum StreamBlockStatus : int32_t {
|
|
UNKNOWN = 0,
|
|
RUNNING = 1,
|
|
TIMED_OUT = 2,
|
|
ABORTED = 3,
|
|
};
|
|
|
|
/*
|
|
StreamBlock implements a baton that will block a the active CUDA stream
|
|
until aborted by the main process.
|
|
*/
|
|
class TORCH_API StreamBlock {
|
|
public:
|
|
virtual ~StreamBlock() = default;
|
|
virtual void abort() = 0;
|
|
virtual StreamBlockStatus status() = 0;
|
|
};
|
|
|
|
std::unique_ptr<StreamBlock> block_stream(std::chrono::milliseconds timeout);
|
|
|
|
// Declare a registry so we can call the CUDA StreamBlock API from CPU only code
|
|
// (i.e. ProcessGroup/Work objects in libtorch_cpu).
|
|
// The implementation lives defined in StreamBlock.cu.
|
|
TORCH_DECLARE_REGISTRY(
|
|
StreamBlockRegistry,
|
|
StreamBlock,
|
|
std::chrono::milliseconds);
|
|
|
|
} // namespace c10d::cuda
|