mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
This fixes a number of issues that were present in https://github.com/pytorch/pytorch/pull/156883, as pointed out by @ngimel.

Test plan: expanded the tests to cover use-after-free behavior and non-default streams:

```
pytest test/distributed/test_c10d_pypg.py -v -k block_current_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158757
Approved by: https://github.com/ngimel
31 lines
631 B
Plaintext
31 lines
631 B
Plaintext
#pragma once
|
|
|
|
#include <atomic>
#include <chrono>
#include <cstdint>

#include <ATen/core/Tensor.h>

#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>
|
|
|
|
namespace c10d::cuda::detail {
|
|
|
|
class StreamBlock : public ::c10d::cuda::StreamBlock {
|
|
public:
|
|
StreamBlock(std::chrono::milliseconds timeout);
|
|
|
|
void abort() override {
|
|
std::atomic_thread_fence(std::memory_order_seq_cst);
|
|
comm_[0] = 1;
|
|
}
|
|
|
|
StreamBlockStatus status() override {
|
|
return static_cast<StreamBlockStatus>(comm_[1].item<int32_t>());
|
|
}
|
|
|
|
private:
|
|
// (abort, cycles)
|
|
const at::Tensor comm_;
|
|
const std::chrono::milliseconds timeout_;
|
|
};
|
|
|
|
} // namespace c10d::cuda::detail
|