mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Some checks failed
quantization-periodic / get-default-label-prefix (push) Has been cancelled
quantization-periodic / periodic-quantization-build (push) Has been cancelled
quantization-periodic / periodic-test-quantization (push) Has been cancelled
weekly / update-commit-hash (push) Has been cancelled
weekly / update-slow-tests (push) Has been cancelled
docker-builds / get-label-type (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-aarch64-py3.10-gcc11, linux.arm64.m7g.4xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks, linux.arm64.m7g.4xlarge, 600) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-linter, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang12-executorch, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang12-onnx, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-clang18-asan, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3-gcc11-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.10-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.10-gcc11, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.12-halide, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.12-triton-cpu, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.13-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-py3.14-clang12, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-rocm-n-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-rocm-n-py3-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-1-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-py3, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-noble-riscv64-py3.12-gcc14, linux.12xlarge) (push) Has been cancelled
docker-builds / docker-build (pytorch-linux-noble-rocm-n-py3, linux.12xlarge) (push) Has been cancelled
ossf-scorecard / Scorecards analysis (push) Has been cancelled
Close nonexistent disable issues / close-nonexistent-disable-issues (push) Has been cancelled
Index PyTorch Tests for Target Determination / get-label-type (push) Has been cancelled
nightly / get-label-type (push) Has been cancelled
nightly / update-commit-hashes (main, .ci/docker/ci_commit_pins, triton, triton-lang) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, audio, pytorch) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, vision, pytorch) (push) Has been cancelled
nightly / update-commit-hashes (main, .github/ci_commit_pins, vllm, vllm-project) (push) Has been cancelled
Index PyTorch Tests for Target Determination / index (push) Has been cancelled
nightly / Link checks (push) Has been cancelled
nightly / docs build (push) Has been cancelled
nightly / docs push (push) Has been cancelled
ghstack-source-id: 59641aa32dc6fd027abf3276017432b693aa71f8
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/165065
Fixes #ISSUE_NUMBER
Opening a new PR for codev.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166136
Approved by: https://github.com/ngimel
57 lines
2.0 KiB
C++
#pragma once

// The InputBuffer class accumulates a list of Variables for use by a
// function. It implements logic to avoid modifying the passed
// values in-place (adding an input twice will accumulate the result).
// This behaviour is needed and used only in backward graphs.

#include <utility>
#include <vector>

#include <c10/core/Stream.h>
#include <torch/csrc/autograd/variable.h>
#include <optional>

namespace torch::autograd {

struct InputBuffer {
  // Creates a buffer with `size` empty slots. The three per-slot
  // stream/event bookkeeping vectors are sized to match `buffer`.
  explicit InputBuffer(size_t size)
      : buffer(size),
        opt_accum_streams(size),
        ready_events(size),
        ready_streams(size) {}
  // Move-only: copying is deleted, moves are defaulted. NOTE(review):
  // presumably buffers are handed off (not shared) inside the engine —
  // confirm against call sites before relying on this.
  InputBuffer(const InputBuffer& other) = delete;
  InputBuffer(InputBuffer&& other) = default;
  // Takes ownership of an already-populated list of inputs; the
  // bookkeeping vectors stay empty in this case.
  explicit InputBuffer(variable_list&& inputs) : buffer(std::move(inputs)) {}
  InputBuffer& operator=(InputBuffer&& other) = default;

  // Accumulates the variable at a specified index.
  // The optional CUDA streams determine which stream the accumulation
  // is run on and how the addition is synchronized.
  // `fn` is the consumer Node the input is destined for (declaration
  // only here; semantics live in the .cpp).
  TORCH_API void add(
      size_t pos,
      Variable&& var,
      const std::optional<c10::Stream>& opt_producer_stream,
      const std::optional<c10::Stream>& opt_consumer_stream,
      Node* fn);

  // Returns a copy of the accumulated variable at `pos` (Variable is a
  // tensor handle, so this copies the handle, not the data).
  Variable operator[](size_t pos) {
    return buffer[pos];
  }

  // Returns the inputs as a list of variables. Destroys given InputBuffer.
  static std::vector<Variable> variables(InputBuffer&& g);

  // One accumulated Variable per expected input of the consumer.
  std::vector<Variable> buffer;
  // The stream used for accumulation when a variable is used multiple times.
  std::vector<std::optional<c10::Stream>> opt_accum_streams;
  // The events you need to wait for to ensure the corresponding buffers
  // are ready. The events are updated as we accumulate into the buffer.
  std::vector<std::optional<c10::Event>> ready_events;
  // The streams corresponding to the events above. This is only used to
  // check if more synchronization is needed or not.
  std::vector<std::optional<c10::Stream>> ready_streams;
};

} // namespace torch::autograd