pytorch/torch/csrc/distributed/c10d/cuda/StreamBlock.hpp
Tristan Rice 1b3d69b59f Work: block_current_stream API (#156883)
This implements a new `wait_stream` API in Work that, for CPU-based backends such as Gloo, matches how `wait` works for ProcessGroupNCCL.

The idea is to support Gloo communication overlap in FSDPv2/HSDP with minimal changes to FSDP.

A previous attempt to make FSDPv2 use `Work.wait` did not work well because of the extensive stream semantics FSDPv2 relies on: https://github.com/pytorch/pytorch/pull/148780

This uses a "Baton" CUDA kernel that spins on a pinned CPU tensor until the host sets it.
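The baton wait can be illustrated on the CPU with a minimal sketch. It makes the following assumptions:

- A `std::atomic<int32_t>` stands in for the pinned CPU tensor.
- A host thread stands in for the CUDA kernel.
- The names `BatonResult` and `spin_on_baton` are hypothetical and not part of PyTorch.

The sketch shows only the polling logic: keep reading the flag until the host sets it or a deadline passes.

```cpp
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <thread>

// Hypothetical outcome of one baton wait (illustrative, not a PyTorch type).
enum class BatonResult { kReleased, kTimedOut };

// CPU analogue of the baton kernel: spin on `flag` (standing in for the pinned
// CPU tensor the CUDA kernel polls) until the host stores a non-zero value or
// `timeout` elapses.
BatonResult spin_on_baton(const std::atomic<int32_t>& flag,
                          std::chrono::milliseconds timeout) {
  assert(timeout.count() >= 0 && "timeout must be non-negative");
  const auto deadline = std::chrono::steady_clock::now() + timeout;
  while (flag.load(std::memory_order_acquire) == 0) {
    if (std::chrono::steady_clock::now() >= deadline) {
      return BatonResult::kTimedOut;
    }
    // Be polite to other CPU threads; a GPU kernel would just keep polling.
    std::this_thread::yield();
  }
  return BatonResult::kReleased;
}
```

In the real design, the GPU is the spinning side. Work queued on the stream after the baton cannot start until the host releases the flag. That is how a CPU backend such as Gloo can make a CUDA stream wait for its communication to finish.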

Test plan:

```
pytest test/distributed/test_c10d_gloo.py -v -k wait_stream
pytest test/distributed/test_c10d_nccl.py -v -k wait_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
2025-07-08 23:55:46 +00:00


#pragma once

#include <chrono>
#include <memory>

#include <c10/util/Registry.h>

namespace c10d::cuda {

enum StreamBlockStatus : int32_t {
  UNKNOWN = 0,
  RUNNING = 1,
  TIMED_OUT = 2,
  ABORTED = 3,
};

/*
StreamBlock implements a baton that blocks the active CUDA stream until it is
aborted by the main process or the timeout expires.
*/
class TORCH_API StreamBlock {
 public:
  virtual ~StreamBlock() = default;
  virtual void abort() = 0;
  virtual StreamBlockStatus status() = 0;
};

std::unique_ptr<StreamBlock> block_stream(std::chrono::milliseconds timeout);

// Declare a registry so we can call the CUDA StreamBlock API from CPU-only code
// (i.e. ProcessGroup/Work objects in libtorch_cpu).
// The implementation is defined in StreamBlock.cu.
TORCH_DECLARE_REGISTRY(
    StreamBlockRegistry,
    StreamBlock,
    std::chrono::milliseconds);

} // namespace c10d::cuda
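A self-contained sketch helps show how the `StreamBlock` interface and the registry fit together. It copies the header's enum and interface into a hypothetical `sketch` namespace, with `TORCH_API` omitted. Everything else is a hypothetical stand-in:

- `ThreadStreamBlock` is a CPU-only implementation in which a host thread plays the baton kernel.
- `stream_block_registry`, `register_stream_block` and `create_stream_block` are a map-based registry. They stand in for what `TORCH_DECLARE_REGISTRY` provides and do not use c10's `Registry`.

Starting in `RUNNING` is also an assumption.

```cpp
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <utility>

namespace sketch {

// Copy of the header's status enum and interface (TORCH_API omitted).
enum StreamBlockStatus : int32_t { UNKNOWN = 0, RUNNING = 1, TIMED_OUT = 2, ABORTED = 3 };

class StreamBlock {
 public:
  virtual ~StreamBlock() = default;
  virtual void abort() = 0;
  virtual StreamBlockStatus status() = 0;
};

// Hypothetical CPU-only StreamBlock: a worker thread spins until abort() is
// called or the timeout expires, then records which of the two happened.
class ThreadStreamBlock : public StreamBlock {
 public:
  explicit ThreadStreamBlock(std::chrono::milliseconds timeout)
      : worker_([this, timeout] { run(timeout); }) {}

  ~ThreadStreamBlock() override {
    abort();  // never leave the worker spinning
    worker_.join();
  }

  void abort() override { abort_flag_.store(1, std::memory_order_release); }

  StreamBlockStatus status() override {
    return static_cast<StreamBlockStatus>(status_.load(std::memory_order_acquire));
  }

 private:
  void run(std::chrono::milliseconds timeout) {
    const auto deadline = std::chrono::steady_clock::now() + timeout;
    while (abort_flag_.load(std::memory_order_acquire) == 0) {
      if (std::chrono::steady_clock::now() >= deadline) {
        status_.store(TIMED_OUT, std::memory_order_release);
        return;
      }
      std::this_thread::yield();
    }
    status_.store(ABORTED, std::memory_order_release);
  }

  std::atomic<int32_t> abort_flag_{0};
  std::atomic<int32_t> status_{RUNNING};
  std::thread worker_;  // declared last so the atomics are initialized first
};

// Hypothetical map-based registry: CPU-only callers look up a factory by key,
// so they never need to include or link the CUDA implementation directly.
using StreamBlockFactory =
    std::function<std::unique_ptr<StreamBlock>(std::chrono::milliseconds)>;

std::map<std::string, StreamBlockFactory>& stream_block_registry() {
  static std::map<std::string, StreamBlockFactory> registry;
  return registry;
}

void register_stream_block(const std::string& key, StreamBlockFactory factory) {
  assert(factory && "factory must be callable");
  stream_block_registry()[key] = std::move(factory);
}

// Returns nullptr when nothing is registered under `key`.
std::unique_ptr<StreamBlock> create_stream_block(const std::string& key,
                                                 std::chrono::milliseconds timeout) {
  auto& registry = stream_block_registry();
  auto it = registry.find(key);
  if (it == registry.end()) {
    return nullptr;
  }
  return it->second(timeout);
}

}  // namespace sketch
```

The indirection mirrors the header comment's motivation. `ProcessGroup`/`Work` code in a CPU-only library can ask for a `StreamBlock` by key and get `nullptr` back if no implementation was linked. It never names a CUDA type.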