#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/input_buffer.h>

#include <ATen/CachedTensorUtils.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/SparseTensorUtils.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/Logging.h>

#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

namespace torch::autograd {

namespace {
// look what you made me do >.<
// Divergent paths for per-Impl stream recording that leak implementation
// details of the impls should not be needed here.
// See https://github.com/pytorch/pytorch/issues/60306
// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
// improved
void record_stream_any_impl(Variable& var, const c10::Stream& stream) {
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  if (stream.device_index() != var.device().index()) {
    return;
  }

  const auto guard = c10::impl::VirtualGuardImpl(device_of(var).value().type());

  if (C10_UNLIKELY(at::isBatchedTensor(var))) {
    auto* impl = at::maybeGetBatchedImpl(var);
    if (impl) {
      guard.recordDataPtrOnStream(impl->value().storage().data_ptr(), stream);
    } else {
      TORCH_INTERNAL_ASSERT(false, "Expected batched tensor");
    }
  } else {
    switch (var.layout()) {
      case c10::kSparseCsr:
      case c10::kSparseCsc:
      case c10::kSparseBsr:
      case c10::kSparseBsc: {
        auto* impl = at::sparse_csr::get_sparse_csr_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->compressed_indices().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->plain_indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kSparse: {
        auto* impl = at::sparse::get_sparse_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kStrided:
        guard.recordDataPtrOnStream(var.storage().data_ptr(), stream);
        break;
      default:
        TORCH_INTERNAL_ASSERT(
            false, "Unknown layout in record_stream_any_impl");
    }
  }
}

bool can_accumulate_inplace(const Variable& v) {
  return (
      // `v` is a "vanilla" Tensor
      !(at::isTensorSubclassLike(v) || v._is_zerotensor() || v.is_nested()) &&

      // with a favorable memory layout
      v.is_non_overlapping_and_dense() &&

      // and we hold the last reference
      at::caching::adjusted_use_count(v) == 1 && v.has_storage() &&
      v.storage().use_count() == 1);
}
} // anonymous namespace
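
// Example (illustrative, hypothetical tensors): a freshly produced dense
// gradient typically satisfies can_accumulate_inplace because the buffer is
// its sole owner:
//
//   auto g = torch::ones({2, 3});  // dense, non-overlapping, use_count == 1
//   can_accumulate_inplace(g);     // true: safe to reuse g's storage
//   auto h = g.view({3, 2});       // h shares g's storage
//   can_accumulate_inplace(g);     // false: storage use_count is now 2
//
// Nested tensors, ZeroTensors, and other tensor subclasses fail the
// "vanilla" Tensor check above regardless of reference counts.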

static void accumulate(
    std::vector<Variable>& buffer,
    const size_t pos,
    Variable&& var) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  auto& old_var = buffer[pos];
  // If we hold the last reference to `old_var` AND its storage we will try to
  // repurpose it to store the output. (Or, if `old_var` is sparse then `var`
  // becomes the candidate output Tensor.) We only do this if:
  //  1) GradMode is disabled since Autograd has special handling for inplace
  //     mutation which we don't want to trigger.
  //
  //  2) We hold the last reference.
  //     (Both `.use_count` and `.storage().use_count()` are one)
  //
  //  3) The candidate tensor is a contiguous, non-overlapping, dense, and
  //     otherwise stock standard Tensor.
  //
  //  4) The candidate is mutable. Currently only ZeroTensors are immutable.
  //
  //  5) The other Tensor is not a Tensor subclass (except sparse), since
  //     it's hard to predict the semantics of arbitrary subclass behavior.

  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (at::GradMode::is_enabled()) {
    buffer[pos] = old_var + var;
  } else if (
      // ATen doesn't route sparse additions correctly...
      old_var.is_sparse() || old_var.is_sparse_csr()) {
    if (can_accumulate_inplace(var)) {
      buffer[pos] = var.add_(old_var);
    } else {
      buffer[pos] = var + old_var;
    }
  } else if (
      can_accumulate_inplace(old_var) && !at::isTensorSubclassLike(var)) {
    buffer[pos] = old_var.add_(var);
  } else {
    buffer[pos] = old_var + var;
  }
}
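
// Worked example (illustrative): with GradMode disabled, adding a dense
// gradient `var` into a sparse `old_var` takes the sparse branch; if `var`
// passes can_accumulate_inplace the sum reuses var's storage via
// var.add_(old_var). Dense-into-dense where the buffer holds the last
// reference to old_var reuses old_var via old_var.add_(var). In every other
// situation (GradMode enabled, subclasses, shared storage) the out-of-place
// `+` allocates a fresh result.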

// Note: [Stream sync contract when dealing with multi-deviced-ness]
//
// An operator can deal with multiple devices, e.g. if it does a device
// transfer, etc. However, for the purpose of stream synchronization, the
// engine is only aware of a single canonical device/stream for each autograd
// Node.
//
// For proper synchronization, the Node author should make sure of the
// following:
//
// 1) A node consuming a gradient should wait on the canonical stream before
//    using it.
// 2) A node producing a gradient should have it ready on the canonical
//    stream during node execution.
//

// Note: [Autograd Producer-Consumer Stream Syncs]
//
// The producer-consumer stream syncs are partially handled in this method
// and partially handled in the engine prior to the consumer's execution.
// The logic here is mainly responsible for handling the synchronization
// needed for accumulation and recording the event that the consumer should
// wait on later. The corresponding wait and record_stream happens in the
// engine.
//
// First producer
// ==============
// There are several things we need to do upon seeing the first producer:
// 1) Determine the accumulation stream (which may or may not be used):
//    case A) var's device matches consumer node's canonical device
//            (The producer node's canonical device may or may not match)
//            -> accumulator stream = consumer stream
//    case B) var's device matches producer node's canonical device
//            and does not match consumer node's canonical device
//            -> accumulator stream = producer stream
//    case C) var's device matches neither
//            -> accumulator stream = var device's current stream
//    See Note [Stream sync contract when dealing with
//    multi-deviced-ness] and the worked example after this note.
// 2) Because we are the first producer, there's no accumulation necessary.
//    Just move var into the buffer.
// 3) Update the ready_events and streams for the current position.**
//    ready_events are events you need to wait for to ensure the
//    corresponding buffers are ready. The events are updated as we
//    accumulate into the buffer.
//
// Nth producer
// ============
// 1) Synchronize for accumulation. Accumulation operates on both the new
//    incoming gradient and the existing gradient in the buffer.
//    (i) wait stream and (ii) record stream to make sure both are ready to
//    be used on the accumulation stream.
// 2) Accumulate on the accumulation stream
// 3) Update the ready event and stream for the current position.**
//
// **As an optimization, we avoid creating and recording an event if we
// know that we won't need to wait on it, saving on the order of
// microseconds.
//
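// Worked example for step 1 of "First producer" (illustrative devices and
// streams): say the consumer node's canonical stream C and the producer's
// canonical stream P are both on cuda:0, and `var` is on cuda:0. Case A
// applies, so the accumulation stream is C, and the first producer records
// an event on P for C to wait on. If the producer instead ran on cuda:1 and
// `var` is on cuda:1 while the consumer stays on cuda:0, case B applies and
// accumulation happens on P. If `var` is on cuda:2, matching neither
// canonical device, case C falls back to cuda:2's current stream.
//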

void InputBuffer::add(
    size_t pos,
    Variable&& var,
    const std::optional<c10::Stream>& opt_producer_stream_,
    const std::optional<c10::Stream>& opt_consumer_stream_,
    Node* fn) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());

  if (!var.defined()) {
    return;
  }
  const auto device = var.device();
  const auto device_type = device.type();
  bool is_accelerator = at::accelerator::isAccelerator(device.type());
  //
  // Non-accelerator case
  //
  if (!is_accelerator) {
    if (!buffer[pos].defined()) {
      buffer[pos] = std::move(var);
    } else {
      c10::OptionalDeviceGuard device_guard{device};
      accumulate(buffer, pos, std::move(var));
    }
    return;
  }
  // Handle the case where var is on an accelerator but the producer node has
  // no canonical stream, e.g. this can happen if the forward is DtoH
  const std::optional<c10::Stream>& opt_producer_stream =
      (opt_producer_stream_.has_value()
           ? opt_producer_stream_
           : std::optional<c10::Stream>(
                 at::accelerator::getCurrentStream(device.index())));

  // opt_consumer_stream is always non-null when is_accelerator is true
  // when InputBuffer is used in the engine. InputBuffer is also used
  // elsewhere, however (e.g. other engine implementations).
  const std::optional<c10::Stream>& opt_consumer_stream =
      (opt_consumer_stream_.has_value()
           ? opt_consumer_stream_
           : std::optional<c10::Stream>(
                 at::accelerator::getCurrentStream(device.index())));

  TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream);

  if (*opt_consumer_stream != *opt_producer_stream &&
      dynamic_cast<AccumulateGrad*>(fn) &&
      at::globalContext().warnOnAccumulateGradStreamMismatch()) {
    TORCH_WARN_ONCE(
        "The AccumulateGrad node's stream does not match the stream of the node that produced "
        "the incoming gradient. This may incur unnecessary synchronization and break CUDA graph "
        "capture if the AccumulateGrad node's stream is the default stream. This mismatch is "
        "caused by an AccumulateGrad node created prior to the current iteration being kept alive. "
        "This can happen if the autograd graph is still being kept alive by tensors such as the "
        "loss, or if you are using DDP, which will stash a reference to the node. To resolve the "
        "mismatch, delete all references to the autograd graph or ensure that DDP initialization is "
        "performed under the same stream as subsequent forwards. If the mismatch is intentional, "
        "you can use torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) to suppress this "
        "warning.");
  }
  // See Note: [Autograd Producer-Consumer Stream Syncs]
  if (!opt_accum_streams[pos].has_value()) {
    // [ First producer ]
    TORCH_INTERNAL_ASSERT(!buffer[pos].defined());
    // 1)
    if (opt_consumer_stream->device() == device) {
      // Case A
      opt_accum_streams[pos] = opt_consumer_stream;
      if (*opt_consumer_stream != *opt_producer_stream) {
        // We will end up doing record_stream on the accumulation stream
        // (which is the consumer stream) later, but we also need to do
        // it here in case we don't end up accumulating.
        record_stream_any_impl(var, *opt_consumer_stream);
      }
    } else if (opt_producer_stream->device() == device) {
      // Case B
      opt_accum_streams[pos] = opt_producer_stream;
    } else {
      // Case C
      opt_accum_streams[pos] =
          at::accelerator::getCurrentStream(device.index());
    }
    // 2)
    buffer[pos] = std::move(var);
    // 3)
    auto& opt_accum_stream = opt_accum_streams[pos];
    TORCH_INTERNAL_ASSERT(opt_accum_stream.has_value());
    if (*opt_consumer_stream != *opt_producer_stream ||
        *opt_accum_stream != *opt_producer_stream) {
      // Either the consumer or accum stream waits for the producer
      // stream depending on whether accumulation is needed.
      auto event = c10::Event{device_type};
      event.record(*opt_producer_stream);
      ready_events[pos] = std::move(event);
    }
    ready_streams[pos] = opt_producer_stream;
  } else {
    // [ Nth producer ]
    auto accum_stream = opt_accum_streams[pos];
    auto& ready_event = ready_events[pos];
    auto& ready_stream = ready_streams[pos];
    TORCH_INTERNAL_ASSERT(accum_stream && ready_stream);
    // 1)
    if (*accum_stream != *opt_producer_stream) {
      auto event = c10::Event{device_type};
      event.record(*opt_producer_stream);
      accum_stream->wait(event);
      record_stream_any_impl(var, *accum_stream);
    }
    if (*accum_stream != *ready_stream) {
      TORCH_INTERNAL_ASSERT(ready_event);
      accum_stream->wait(*ready_event);
      // This is redundant for case A, but needed for case C
      record_stream_any_impl(buffer[pos], *accum_stream);
    }
    // 2)
    c10::OptionalStreamGuard stream_guard{accum_stream};
    accumulate(buffer, pos, std::move(var));
    // 3)
    if (*opt_consumer_stream != *accum_stream) {
      // Only the consumer stream needs to wait for this event
      auto event = c10::Event{device_type};
      event.record(*accum_stream);
      ready_events[pos] = std::move(event);
    }
    ready_streams[pos] = accum_stream;
  }
}

auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
  std::vector<Variable> result = std::move(g.buffer);
  return result;
}
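
// Usage sketch (assumed calling convention; names are illustrative): the
// engine typically allocates one InputBuffer per Node with one slot per
// input edge, feeds it every incoming gradient, then moves the results out:
//
//   InputBuffer grads(fn->num_inputs());
//   grads.add(input_nr, std::move(grad), producer_stream, consumer_stream,
//             fn.get());
//   auto inputs = InputBuffer::variables(std::move(grads));
//
// After variables() the buffer is moved-from and must not be reused.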

} // namespace torch::autograd