pytorch/torch/csrc/autograd/input_buffer.cpp
richard 382ef1fda7 Autograd graphtask trim unnecessary edges (#82544)
### Introduction
<!-- What did you change and why was it needed? -->

Removing unnecessary weight gradient calculation is very important for applications that need high-order derivatives during training. However, this is not supported by the current Autograd engine.

In more detail: the backward function of a `matmul`-style operator (e.g., `linear`, `addmm`, `mm`) performs two matmuls, one for the `input gradient` and another for the `weight gradient`. For a typical neural network (nn) with a few linear layers and activation functions, if the user calls `torch.autograd.grad()` to calculate the derivative of the nn output `y` w.r.t. the nn input `x`, only the `input gradient` of the `matmul` operator is needed, and the `weight gradient` is discarded. However, the current PyTorch autograd engine always calculates the `weight gradient` if `weight` requires gradient, which is the case when the high-order derivative is calculated during training.
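
As a worked illustration of the two matmuls, assume `linear` computes $y = xW^{\top} + b$ with $x$ of shape $(N, \text{in})$ and $W$ of shape $(\text{out}, \text{in})$, and write $g = \partial L / \partial y$ for the incoming gradient (the symbols $L$, $g$, and $N$ are introduced here only for illustration):

$$
\frac{\partial L}{\partial x} = g\,W, \qquad \frac{\partial L}{\partial W} = g^{\top} x
$$

When only the derivative w.r.t. `x` is requested, the second product is computed and then thrown away.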

The figure attached shows the autograd graph of the following code snippet:
```py
y = torch.nn.functional.linear(x, weight, bias)
y = y.pow(2)
# first order derivative
y__x, = torch.autograd.grad(y, x, grad_outputs=grad_outputs, create_graph=True)
# second order derivative
y__x__x, = torch.autograd.grad(y__x, x, grad_outputs=grad_outputs, create_graph=True)
```
The path marked in the figure below is not needed when calculating derivatives.

<img width="50%" alt="image" src="https://user-images.githubusercontent.com/9999318/182018117-719c5a23-bcc6-4a63-8e8d-1bca3ebda2e3.png">

### Issue
<!-- Link to Issue ticket or RFP -->
Related issue: https://github.com/pytorch/pytorch/issues/56500

### Method
When calling `torch.autograd.grad`, `exec_info_` is created for each GraphTask, which allows filtering paths on the graph that are not needed. However, when the GraphTask calls into the node, the node still does not know whether the edges are needed or not. In the case of matmul, `weight.requires_grad is True` so the weight gradient is always calculated.

Following https://github.com/pytorch/pytorch/issues/56500#issuecomment-825694656, this PR passes the graph task's thread_local `exec_info_` into the node, so that the node can trim unnecessary edges during `torch.autograd.grad` calls.
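
The idea can be sketched with a toy model. This is not the engine code: plain Python lists stand in for tensors, and the names `LinearBackward`, `needs_grad`, and `matmul_calls` are hypothetical, chosen only to show how a per-task edge mask lets a node skip the weight matmul.

```py
# Toy model only: lists stand in for tensors; none of these names are PyTorch API.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*a)]


class LinearBackward:
    """Backward node for y = x @ W^T with two output edges: (x, W)."""

    def __init__(self, x, weight):
        self.x = x
        self.weight = weight
        self.matmul_calls = 0  # counts how many matmuls this node ran

    def apply(self, grad_out, needs_grad):
        # needs_grad[i] plays the role of the per-task edge filter:
        # True means the current graph task needs output edge i.
        grad_x = grad_w = None
        if needs_grad[0]:
            grad_x = matmul(grad_out, self.weight)  # input gradient: g @ W
            self.matmul_calls += 1
        if needs_grad[1]:
            grad_w = matmul(transpose(grad_out), self.x)  # weight gradient: g^T @ x
            self.matmul_calls += 1
        return grad_x, grad_w


x = [[1.0, 2.0]]                    # shape (1, 2)
weight = [[3.0, 4.0], [5.0, 6.0]]   # shape (2, 2)
grad_out = [[1.0, 1.0]]             # shape (1, 2)

node = LinearBackward(x, weight)
# A gradient w.r.t. x only needs edge 0, so the weight matmul is skipped.
grad_x, grad_w = node.apply(grad_out, needs_grad=(True, False))
```

Without the mask, the node can only consult `weight.requires_grad`, which is `True` during training, so both matmuls run on every call.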

### Benchmark
Benchmark script: https://gist.github.com/yueyericardo/24158433a2021c51eeef9c3e2722df99

Benchmark result:
6 hidden layers, batch size 10000, on A100

FP32 result
| hessian benchmark             | FP32 (before) | FP32 (after)      | FP32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 55.658 ms     | 29.392 ms (1.90X) | 29.547 ms (1.90X)       |
| Linear + ReLU (with backward) | 81.173 ms     | 54.917 ms (1.47X) | 68.988 ms (1.18X)       |

TF32 result
| hessian benchmark             | TF32 (before) | TF32 (after)      | TF32 (Functorch v0.1.1) |
| ----------------------------- | ------------- | ----------------- | ----------------------- |
| Linear + ReLU (no backward)   | 19.801 ms     | 11.259 ms (1.76X) | 10.754 ms (1.84X)       |
| Linear + ReLU (with backward) | 29.167 ms     | 20.466 ms (1.42X) | 22.784 ms (1.28X)       |

For the FP32 result, we get a 1.9X speedup for the hessian calculation and a 1.47X speedup during training, which is even faster than the functorch `vmap(jacfwd(jacrev` implementation. (functorch has a performance regression in v0.2.0, https://github.com/pytorch/functorch/issues/989, so we use v0.1.1 for the benchmark.)

@zou3519 does functorch also include similar optimizations during hessian calculation? If not, what do we need to do so that functorch can also benefit from this PR?

### Testing
<!-- How did you test your change? -->

- [x] we need to figure out a way for unittest

### Thanks
Thanks for the great blog: [How Computational Graphs are Executed in PyTorch | PyTorch](https://pytorch.org/blog/how-computational-graphs-are-executed-in-pytorch/)

cc @zasdfgbnm @albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82544
Approved by: https://github.com/soulitzer
2022-08-11 18:50:09 +00:00

218 lines
7.6 KiB
C++

#include <torch/csrc/autograd/input_buffer.h>
#include <ATen/BatchedTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/Optional.h>
#include <cstddef>
#include <utility>
#include <vector>
namespace torch {
namespace autograd {
namespace {
// look what you made me do >.<
// Divergent paths for per-Impl stream recording that leak implementation
// details of the impls should not be needed here.
// See https://github.com/pytorch/pytorch/issues/60306
// TODO: clean this up when https://github.com/pytorch/pytorch/issues/60306 is
// improved
void record_stream_any_impl(Variable& var, c10::Stream& stream) {
  const auto guard = c10::impl::VirtualGuardImpl(c10::DeviceType::CUDA);
  if (C10_UNLIKELY(at::isBatchedTensor(var))) {
    auto* impl = at::maybeGetBatchedImpl(var);
    if (impl) {
      guard.recordDataPtrOnStream(impl->value().storage().data_ptr(), stream);
    } else {
      TORCH_INTERNAL_ASSERT(false, "Expected batched tensor");
    }
  } else {
    switch (var.layout()) {
      case c10::kSparseCsr:
      case c10::kSparseCsc:
      case c10::kSparseBsr:
      case c10::kSparseBsc: {
        auto* impl = at::sparse_csr::get_sparse_csr_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->compressed_indices().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->plain_indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kSparse: {
        auto* impl = at::sparse::get_sparse_impl(var);
        guard.recordDataPtrOnStream(
            impl->values().storage().data_ptr(), stream);
        guard.recordDataPtrOnStream(
            impl->indices().storage().data_ptr(), stream);
        break;
      }
      case c10::kStrided:
        guard.recordDataPtrOnStream(var.storage().data_ptr(), stream);
        break;
      default:
        TORCH_INTERNAL_ASSERT(
            false, "Unknown layout in record_stream_any_impl");
    }
  }
}
} // anonymous namespace
static void accumulate(
    std::vector<Variable>& buffer,
    const size_t pos,
    Variable&& var) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  auto& old_var = buffer[pos];
  // ATen doesn't route sparse additions correctly...
  // do dense + sparse in-place if possible
  if (old_var.is_sparse()) {
    // It is safe to change the Tensor inplace if the Tensor is only used in
    // this buffer (this could be the gradient passed by the user) and that no
    // other Tensor is using the same storage.
    if (!var.is_sparse() && var.is_contiguous() && var.use_count() == 1 &&
        var.storage().use_count() == 1) {
      buffer[pos] = var.add_(old_var);
    } else {
      buffer[pos] = var + old_var;
    }
  } else {
    if (var.is_sparse() && !old_var.is_sparse() && old_var.is_contiguous() &&
        old_var.use_count() == 1 && old_var.storage().use_count() == 1) {
      buffer[pos] = old_var.add_(var);
    } else {
      buffer[pos] = old_var + var;
    }
  }
}
void InputBuffer::add(
    size_t pos,
    Variable&& var,
    const c10::optional<c10::Stream>& opt_producer_stream,
    const c10::optional<c10::Stream>& opt_consumer_stream) {
  TORCH_INTERNAL_ASSERT(pos < buffer.size());
  if (!var.defined()) {
    return;
  }
  // Switches to accumulate device
  // The device (and stream) chosen for accumulation is:
  //  (1) var is not a CUDA variable. Accumulation happens on var's device.
  //  (2) var is a CUDA variable and it, the consumer, and the producer share
  //      the same device:
  //       (2a) Uses the consumer's stream as the accumulation stream
  //       (2b) Syncs the accumulation stream with the producer's stream (if
  //            different)
  //       (2c) Accumulates.
  //  (3) var is a CUDA variable and it shares a device with the consumer but
  //      not the producer:
  //       (3a) Uses the consumer's stream as the accumulation stream
  //       (3b) Syncs the accumulation stream with the consumer device's
  //            default stream
  //       (3c) Accumulates.
  //  (4) var is a CUDA variable and it shares a device with the producer but
  //      not the consumer:
  //       (4a) Uses the producer device's default stream as the accumulation
  //            stream
  //       (4b) Syncs the accumulation stream with the producer's stream
  //       (4c) Accumulates.
  //  (5) var is a CUDA variable and it does not share a device with the
  //      consumer or producer.
  //      Accumulation happens on the var device's default stream.
  TORCH_INTERNAL_ASSERT(device_of(var));
  c10::optional<c10::Stream> opt_accumulate_stream = c10::nullopt;
  if (device_of(var)->is_cuda()) {
    const auto on_producer =
        opt_producer_stream && device_of(var) == opt_producer_stream->device();
    const auto on_consumer =
        opt_consumer_stream && device_of(var) == opt_consumer_stream->device();
    if (on_producer && on_consumer) {
      // (2a)
      opt_accumulate_stream = opt_consumer_stream;
      if (opt_accumulate_stream != opt_producer_stream) {
        // (2b)
        auto event = c10::Event{c10::DeviceType::CUDA};
        event.record(*opt_producer_stream);
        opt_accumulate_stream->wait(event);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    } else {
      c10::optional<c10::Stream> opt_sync_stream = c10::nullopt;
      const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
      if (on_consumer && !on_producer) {
        // (3a)
        opt_accumulate_stream = opt_consumer_stream;
        opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device());
      } else if (on_producer && !on_consumer) {
        // (4a)
        opt_accumulate_stream =
            guard.getDefaultStream(opt_producer_stream->device());
        opt_sync_stream = opt_producer_stream;
      } else {
        // (5)
        opt_accumulate_stream = guard.getDefaultStream(*device_of(var));
      }
      if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
        // (3b), (4b)
        c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
        auto event = c10::Event{c10::DeviceType::CUDA};
        event.record(*opt_sync_stream);
        opt_accumulate_stream->wait(event);
        const auto guard = c10::impl::VirtualGuardImpl(c10::DeviceType::CUDA);
        record_stream_any_impl(var, *opt_accumulate_stream);
      }
    }
  }
  auto& old_var = buffer[pos];
  if (!old_var.defined()) {
    buffer[pos] = std::move(var);
  } else {
    if (opt_accumulate_stream) {
      c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
      accumulate(buffer, pos, std::move(var));
    } else {
      // (1) non-CUDA variable
      //     Accumulation happens on variable's device
      c10::OptionalDeviceGuard device_guard{device_of(var)};
      accumulate(buffer, pos, std::move(var));
    }
  }
}
auto InputBuffer::device() const -> at::Device {
  // Since we pick the first non-CPU tensor, this won't work with
  // mixed device-type operations (e.g., an op that is both CUDA
  // and XLA).  This is *incredibly* unlikely, so we don't worry
  // about it.
  for (auto& var : buffer) {
    if (var.defined()) {
      auto device = var.device();
      if (device.type() != at::kCPU) {
        return device;
      }
    }
  }
  // Only report to the CPU thread if there really were no tensors
  // from other devices.
  return at::kCPU;
}
auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
  std::vector<Variable> result = std::move(g.buffer);
  return result;
}
} // namespace autograd
} // namespace torch
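
The accumulation-stream cases (1) to (5) documented in `InputBuffer::add` can be summarized with a simplified sketch. It makes no CUDA calls: devices are plain strings, a stream is a `(device, name)` tuple, a device's default stream is modelled as `(device, "default")`, and the function name `choose_streams` and its return shape are hypothetical.

```py
# Simplified model of the stream choice in InputBuffer::add; not PyTorch API.

def choose_streams(var_device, producer, consumer):
    """Return (accumulate_stream, sync_stream) for cases (1)-(5).

    producer/consumer are (device, name) tuples or None. sync_stream is the
    stream the accumulation stream must wait on, or None if no sync is needed.
    """
    if not var_device.startswith("cuda"):
        # (1) non-CUDA variable: accumulate on var's device, no stream involved
        return None, None
    on_producer = producer is not None and producer[0] == var_device
    on_consumer = consumer is not None and consumer[0] == var_device
    if on_producer and on_consumer:
        # (2) consumer's stream, synced with the producer's stream if different
        return consumer, (producer if producer != consumer else None)
    if on_consumer:
        # (3) consumer's stream, synced with the consumer device's default stream
        default = (consumer[0], "default")
        return consumer, (default if default != consumer else None)
    if on_producer:
        # (4) producer device's default stream, synced with the producer's stream
        default = (producer[0], "default")
        return default, (producer if producer != default else None)
    # (5) default stream of var's device, nothing to sync with
    return (var_device, "default"), None


# Case (3): var lives on the consumer's device but not the producer's.
acc, sync = choose_streams("cuda:0", ("cuda:1", "s1"), ("cuda:0", "s2"))
```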