pytorch/torch/csrc/autograd/input_buffer.h
soulitzer b3861ac8e7
Some checks failed
quantization-periodic / get-default-label-prefix (push) Has been cancelled
quantization-periodic / periodic-quantization-build (push) Has been cancelled
quantization-periodic / periodic-test-quantization (push) Has been cancelled
weekly / update-commit-hash (push) Has been cancelled
weekly / update-slow-tests (push) Has been cancelled
docker-builds / get-label-type (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-aarch64-py3.10-gcc11, linux.arm64.m7g.4xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks, linux.arm64.m7g.4xlarge, 600) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-linter, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang12-executorch, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang12-onnx, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang18-asan, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-gcc11-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.10-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.10-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.12-halide, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.12-triton-cpu, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.13-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.14-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-rocm-n-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-rocm-n-py3-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-1-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-noble-riscv64-py3.12-gcc14, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-noble-rocm-n-py3, linux.12xlarge) (push) Has been cancelled
ossf-scorecard / Scorecards analysis (push) Has been cancelled
Close nonexistent disable issues / close-nonexistent-disable-issues (push) Has been cancelled
Index PyTorch Tests for Target Determination / get-label-type (push) Has been cancelled
nightly / get-label-type (push) Has been cancelled
nightly / update-commit-hashes (main, .ci/docker/ci_commit_pins, triton, triton-lang) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, audio, pytorch) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, vision, pytorch) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, vllm, vllm-project) (push) Has been cancelled
Index PyTorch Tests for Target Determination / index (push) Has been cancelled
nightly / Link checks (push) Has been cancelled
nightly / docs build (push) Has been cancelled
nightly / docs push (push) Has been cancelled
[reland] Warn if AccumulateGrad stream does not match producer node stream (#166136)
ghstack-source-id: 59641aa32dc6fd027abf3276017432b693aa71f8
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/165065

Fixes #ISSUE_NUMBER

Opening a new PR for codev

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166136
Approved by: https://github.com/ngimel
2025-11-01 12:33:48 +00:00

57 lines
2.0 KiB
C++

#pragma once
// The InputBuffer class accumulates a list of Variables for use by a
// function. It implements logic to avoid modifying the passed
// values in-place (adding an input twice will accumulate the result).
// This behaviour is needed and used only in backward graphs.
#include <utility>
#include <vector>
#include <c10/core/Stream.h>
#include <torch/csrc/autograd/variable.h>
#include <optional>
namespace torch::autograd {
// Staging area that collects the input Variables of a function in a backward
// graph. Slot `pos` of `buffer` holds the (possibly accumulated) value for
// input `pos`. The type is move-only.
struct InputBuffer {
  // Creates `size` default-constructed slots, plus stream/event bookkeeping
  // vectors of the same length (every entry starts out as std::nullopt).
  explicit InputBuffer(size_t size)
      : buffer(size),
        opt_accum_streams(size),
        ready_events(size),
        ready_streams(size) {}

  // Adopts an already populated list of inputs. Only `buffer` is initialized
  // here: opt_accum_streams, ready_events and ready_streams are left empty.
  // NOTE(review): add() is defined out of line; it presumably indexes those
  // vectors by `pos`, so confirm before calling add() on a buffer that was
  // built with this constructor.
  explicit InputBuffer(variable_list&& inputs) : buffer(std::move(inputs)) {}

  // Move-only. Both copy operations are spelled out as deleted so the
  // contract is visible without knowing the implicit-deletion rules.
  InputBuffer(const InputBuffer& other) = delete;
  InputBuffer& operator=(const InputBuffer& other) = delete;
  InputBuffer(InputBuffer&& other) = default;
  InputBuffer& operator=(InputBuffer&& other) = default;

  // Accumulates the variable at a specified index.
  // The optional producer/consumer streams (c10::Stream) determine which
  // stream the accumulation is run on and how the addition is synchronized.
  TORCH_API void add(
      size_t pos,
      Variable&& var,
      const std::optional<c10::Stream>& opt_producer_stream,
      const std::optional<c10::Stream>& opt_consumer_stream,
      Node* fn);

  // Returns the Variable stored at `pos` by value. No bounds check is
  // performed (plain std::vector indexing). Const-qualified because it does
  // not modify the buffer, so it is usable on const InputBuffer objects too.
  Variable operator[](size_t pos) const {
    return buffer[pos];
  }

  // Returns the inputs as a list of variables. Destroys given InputBuffer.
  static std::vector<Variable> variables(InputBuffer&& g);

  // One entry per input of the function.
  std::vector<Variable> buffer;
  // The stream used for accumulation when a variable is used multiple times.
  std::vector<std::optional<c10::Stream>> opt_accum_streams;
  // The events you need to wait for to ensure the corresponding buffers
  // are ready. The events are updated as we accumulate into the buffer.
  std::vector<std::optional<c10::Event>> ready_events;
  // The streams corresponding to the events above. This is only used to
  // check if more synchronization is needed or not.
  std::vector<std::optional<c10::Stream>> ready_streams;
};
} // namespace torch::autograd