Summary:
Often, we find ourselves looking at some long-running kernel or emit_nvtx range on an nvvp profile and trying to connect it to the offending line in a training script. If the op is in the forward pass, that's easy: ops are enqueued explicitly from the Python side, so tracking one down with manual nvtx ranges supplemented by the built-in emit_nvtx ranges is straightforward. If the op is in the backward pass, it's much more difficult. From the Python side, all you can do is wrap loss.backward() in an nvtx range, and if you also use emit_nvtx, the automatic ranges provide only local information. Right now, the only consistent way to connect backward-pass kernels to their associated forward-pass lines of Python is to understand your script line by line and know exactly where in the backward pass you are.

This PR augments the existing nvtx machinery to bridge the gap between forward and backward, allowing backward-pass Function apply calls to be connected to the forward-pass operations that required/created those Functions. The method is simple and surgical. During the forward pass, when running with emit_nvtx, the nvtx range for each function in VariableType is tagged with the current sequence number. During the backward pass, the nvtx range associated with each Function's operator() is tagged with that Function's stashed sequence number, which can be compared to the "current sequence numbers" from the forward pass to locate the associated op.

Double-backward is not a problem. If a backward pass with create_graph=True is underway, the relationship between backward and double-backward is conceptually the same as the relationship between forward and backward: the functions in VariableType still spit out current-sequence-number-tagged ranges, the Function objects they create still stash those sequence numbers, and in the eventual double-backward execution, their operator() ranges are still tagged with the stashed numbers, which can be compared to the "current sequence numbers" from the backward pass.

Minor caveats:
- The sequence number is thread-local, and many VariableType functions (specifically, those without a derivative explicitly defined in derivatives.yaml) don't create an associated Function object, instead delegating that to sub-functions further down the call chain, perhaps called from within at::native functions that route back through VariableType by calling at::function_name. So the correspondence of stashed sequence numbers in Function operator() ranges with numbers in forward-pass ranges is not guaranteed to be one-to-one. However, it's still a vast improvement over the current situation, and I don't think this issue should be a blocker.
- Feel free to litigate my use of stringstream in profiler.cpp. I did it because it was easy and clean. If that's too big a hammer, let's figure out something more lightweight.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/10881
Differential Revision: D9833371
Pulled By: apaszke
fbshipit-source-id: 1844f2e697117880ef5e31394e36e801d1de6088
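To make the tagging scheme concrete, here is a minimal sketch of the idea. It is not the code added to profiler.cpp; the helper names (push_forward_range, push_backward_range) and the label format are illustrative assumptions. It only shows how a forward-pass range carrying the current sequence number can be matched against a backward-pass range carrying a Function's stashed sequence number, using the NVTX C API.

    // Sketch only: names and label format are assumptions, not profiler.cpp.
    #include <nvToolsExt.h>

    #include <cstdint>
    #include <sstream>
    #include <string>

    // Forward pass: open a range for an op, tagged with the *current* sequence
    // number, i.e. the number that the Function created by this op will stash.
    void push_forward_range(const std::string& op_name, uint64_t current_seq_nr) {
      std::stringstream label;
      label << op_name << ", seq = " << current_seq_nr;
      nvtxRangePushA(label.str().c_str());
    }

    // Backward pass: open a range for a Function's operator(), tagged with the
    // sequence number that the Function stashed during the forward pass.
    void push_backward_range(const std::string& fn_name, uint64_t stashed_seq_nr) {
      std::stringstream label;
      label << fn_name << ", stashed seq = " << stashed_seq_nr;
      nvtxRangePushA(label.str().c_str());
    }

    // Each pushed range is closed with nvtxRangePop(). Matching "seq = N" from
    // the forward pass against "stashed seq = N" in the backward pass links a
    // backward-pass kernel to the forward-pass op that created its Function.
    void pop_range() {
      nvtxRangePop();
    }

In function.cpp below, the thread-local counter and its peek_at_next_sequence_nr accessor are what make the forward-pass "current sequence number" readable without consuming it.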
136 lines
4.4 KiB
C++
#include "torch/csrc/autograd/function.h"
|
|
|
|
#include "torch/csrc/autograd/engine.h"
|
|
#include "torch/csrc/autograd/variable.h"
|
|
#include "torch/csrc/jit/ir.h"
|
|
|
|
#include <ATen/ATen.h>
|
|
|
|
#include <algorithm>
|
|
#include <cstdint>
|
|
#include <memory>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
#include <deque>
|
|
|
|
namespace torch { namespace autograd {

/// Monotonically incrementing (thread local!) counter to supply sequence
/// numbers.
thread_local uint64_t Function_next_sequence_nr_ = 0;

uint64_t Function::peek_at_next_sequence_nr() {
  return Function_next_sequence_nr_;
}

uint64_t& Function::get_next_sequence_nr() {
  return Function_next_sequence_nr_;
}
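// The profiler can call peek_at_next_sequence_nr() to read the number that the
// next Function constructed on this thread will receive, without consuming it;
// get_next_sequence_nr() returns a mutable reference so that a newly
// constructed Function can post-increment it to claim its own number.
// Comparing a Function's stashed number with the "current" number recorded in
// a forward-pass emit_nvtx range is what links backward-pass work to the
// forward-pass op that created it.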
auto Function::name() const -> std::string {
  return at::demangle(typeid(*this).name());
}
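// Anomaly-mode metadata is created lazily, on first request, via the engine.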
AnomalyMetadata* Function::metadata() noexcept {
  if (!anomaly_metadata_) {
    anomaly_metadata_ = Engine::get_default_engine().make_anomaly_metadata();
  }
  return anomaly_metadata_.get();
}
/*
 * Fix for #5534: prevent stack overflow on deletion of deep computation graph
 *
 * Sometimes one can end up with a very big computation graph of Functions
 * and Edges. Each std::shared_ptr<Function> contains a list of Edge, and
 * each Edge contains a std::shared_ptr<Function>. Deleting a
 * std::shared_ptr<Function> can trigger the recursive deletion of other
 * std::shared_ptr<Function>'s: this can stack overflow if the graph
 * is deep enough. Here is an example of such a graph:
 *
 * shared_ptr<Function> -> Edge -> shared_ptr<Function> -> Edge -> ... -> shared_ptr<Function>
 *
 * The solution here is to use a custom deleter with each
 * std::shared_ptr<Function>. The custom deleter keeps track of how many
 * nested deleters it is in. When this number exceeds the maximum allowed
 * depth, the Function* to be deleted are accumulated in a per-thread
 * delete queue and handled by one of the deleters.
 *
 * Note that these custom deleters are NOT necessary for deleting PyFunction.
 * This is because a THPFunction Python object owns a PyFunction that is in a
 * computation graph. When Python objects get recursively destroyed, they
 * are also queued into a delete list. This happens very early for them
 * (at 50 deleters): https://github.com/python/cpython/blob/f320be77ffb73e3b9e7fc98c37b8df3975d84b40/Include/object.h#L1024-L1063
 * so we don't need to worry about them.
 */
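// Per-thread state for the custom deleter: the Functions whose deletion has
// been deferred, and the current nesting depth of deleteFunction calls on
// this thread.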
thread_local std::deque<Function*> deleteFunctionQueue;
thread_local size_t deleteFunctionRecursionDepth = 0;
/*
 * If this number is set too high, a deep computation graph can still
 * stack overflow. The procedure for setting this number was to
 * 1) find, on various machines, the smallest value at which the limit no
 *    longer guarded against stack overflows, and
 * 2) take the minimum of all such values and subtract some leeway, because
 *    the memory of these stack frames will probably grow as time passes.
 *
 * Testing on a few machines, the magic numbers were:
 * - Mac OSX (Macbook Pro 15): ~60000
 * - A beefy Ubuntu 16.04 box: ~15000
 * - Windows AWS instance (g3.4xlarge): variable. My two attempts at
 *   different times got ~8300 and 3669.
 */
#ifdef _WIN32
size_t deleteFunctionMaxRecursionDepth = 3000;
#else
size_t deleteFunctionMaxRecursionDepth = 10000;
#endif
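// RAII guard that bumps the per-thread recursion depth for the lifetime of
// one deleteFunction call and exposes the current value.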
struct RecursionDepthCounter {
 public:
  explicit RecursionDepthCounter() {
    ++deleteFunctionRecursionDepth;
  }
  ~RecursionDepthCounter() {
    --deleteFunctionRecursionDepth;
  }

  size_t value() {
    return deleteFunctionRecursionDepth;
  }
};
/*
 * Note that the custom deleter deletes in BFS style. Without using
 * the custom deleter, the computation graph is deleted in a DFS style.
 * The BFS deletion is valid (and safe) because if a shared_ptr<Function>'s
 * reference count hits 0, nothing else will access it.
 */
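// Illustration (with a hypothetical Function subclass): graph-owning
// shared_ptr<Function>s are expected to be created with this deleter, e.g.
//
//   std::shared_ptr<Function> fn(new MulBackward(), deleteFunction);
//
// so destroying a long Function -> Edge -> Function chain bottoms out in the
// per-thread queue below rather than in deeply nested destructor calls.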
void deleteFunction(Function* function) {
  RecursionDepthCounter recursion_depth;

  if (recursion_depth.value() > deleteFunctionMaxRecursionDepth) {
    deleteFunctionQueue.push_back(function);
    return;
  }

  delete function;

  if (deleteFunctionQueue.empty()) {
    return;
  }
  if (recursion_depth.value() != deleteFunctionMaxRecursionDepth) {
    AT_ERROR("Only one deleter per thread should be able to process "
             "the delete queue. Please open an issue.");
  }
  while (!deleteFunctionQueue.empty()) {
    auto queued_function = deleteFunctionQueue.front();
    deleteFunctionQueue.pop_front();
    delete queued_function;
  }
}

}} // namespace torch::autograd