pytorch/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh
Tristan Rice 4366610f5a [c10d] block_current_stream: correctness fixes (#158757)
This fixes a number of issues that were present in https://github.com/pytorch/pytorch/pull/156883 as pointed out by @ngimel

Test plan:

Expanded tests to cover use-after-free behavior and non-default streams.

```
pytest test/distributed/test_c10d_pypg.py -v -k block_current_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158757
Approved by: https://github.com/ngimel
2025-07-21 22:23:44 +00:00

31 lines
631 B
Plaintext

#pragma once
#include <atomic>
#include <chrono>
#include <cstdint>
#include <ATen/core/Tensor.h>
#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>
namespace c10d::cuda::detail {
/// Implementation of ::c10d::cuda::StreamBlock living in c10d::cuda::detail.
///
/// All state visible in this header is exchanged through `comm_`: abort()
/// writes slot 0 and status() reads slot 1.
class StreamBlock : public ::c10d::cuda::StreamBlock {
 public:
  /// Declared here; the definition lives outside this header.
  ///
  /// Marked `explicit` so a bare duration is never implicitly converted into
  /// a StreamBlock.
  ///
  /// @param timeout duration handed to the out-of-line constructor
  ///        (presumably stored in `timeout_` -- confirm in the definition).
  explicit StreamBlock(std::chrono::milliseconds timeout);

  /// Requests an abort by storing 1 into slot 0 of `comm_`.
  void abort() override {
    // Sequentially consistent fence issued before the flag store.
    // NOTE(review): the reader of slot 0 is not visible in this header.
    std::atomic_thread_fence(std::memory_order_seq_cst);
    comm_[0] = 1;
  }

  /// Reads slot 1 of `comm_` as an int32_t and returns it converted to
  /// StreamBlockStatus.
  StreamBlockStatus status() override {
    return static_cast<StreamBlockStatus>(comm_[1].item<int32_t>());
  }

 private:
  // Communication buffer; slots used in this header: (abort, status).
  //   [0] abort flag, set to 1 by abort()
  //   [1] value returned by status(), interpreted as a StreamBlockStatus
  // NOTE(review): this was previously labelled "(abort, cycles)", which does
  // not match how status() reads slot 1 -- confirm against the writer of
  // slot 1.
  const at::Tensor comm_;
  // NOTE(review): presumably the constructor's `timeout` argument; not read
  // anywhere in this header.
  const std::chrono::milliseconds timeout_;
};
} // namespace c10d::cuda::detail