This PR adds a private runtime feature flag for the feature work we're going to do with extending autograd.Function. The motivation for the feature flag is:
- to guard the feature against unsuspecting users
- to control when the feature is released, once we are ready

We might not even need the feature flag (because we hope to have the work done in the next month), but it is good practice, and the work does touch a currently public API (autograd.Function).

Concretely, "autograd.Function extension" refers to:
- adding an optional `setup_context` staticmethod to autograd.Function
- adding an optional `vmap` staticmethod to autograd.Function
- autograd.Function support for functorch

Test Plan:
- new test that the feature flag works

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89858
Approved by: https://github.com/soulitzer
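For reference, the flag is exposed through the getter/setter pair defined at the bottom of the file below. A minimal sketch of toggling it from C++, assuming the declarations are available via `torch/csrc/autograd/function.h`:

```cpp
#include <torch/csrc/autograd/function.h>

int main() {
  using namespace torch::autograd;
  // The extension is disabled by default; opt in explicitly.
  setAutogradFunctionExtensionEnabled(true);
  if (isAutogradFunctionExtensionEnabled()) {
    // ... exercise the setup_context/vmap codepaths here ...
  }
  return 0;
}
```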
116 lines | 3.4 KiB | C++
#include <torch/csrc/autograd/function.h>

#include <c10/util/ThreadLocal.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/ATen.h>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace torch {
namespace autograd {

// The current evaluating node. This is useful to assign the current node as a
// parent of new nodes created during the evaluation of this node in anomaly
// mode.
C10_DEFINE_TLS_static(std::shared_ptr<Node>, tls_current_evaluating_node);
#define current_evaluating_node (tls_current_evaluating_node.get())

NodeGuard::NodeGuard(std::shared_ptr<Node> node) {
  last_evaluating_node_ = std::move(current_evaluating_node);
  current_evaluating_node = std::move(node);
}
NodeGuard::~NodeGuard() {
  // restore the previous evaluating node
  current_evaluating_node = std::move(last_evaluating_node_);
}

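// Illustrative usage (a sketch, not code from this file): the autograd
// engine scopes the evaluation of a node with a NodeGuard so that nodes
// constructed while it runs can record it as their anomaly-mode parent:
//
//   {
//     NodeGuard guard(node); // `node` becomes current_evaluating_node
//     // ... nodes created here see `node` via Node::assign_parent() ...
//   } // previous evaluating node restored by ~NodeGuard
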
void Node::assign_parent() {
  metadata()->assign_parent(current_evaluating_node);
}

auto Node::name() const -> std::string {
  return c10::demangle(typeid(*this).name());
}

AnomalyMetadata* Node::metadata() noexcept {
  if (!anomaly_metadata_) {
    anomaly_metadata_ = Engine::get_default_engine().make_anomaly_metadata();
  }
  return anomaly_metadata_.get();
}

static void gatherFunctions(
    Node* func,
    std::vector<std::shared_ptr<Node>>& stack) {
  func->release_variables();

  for (auto& edge : func->next_edges()) {
    if (edge.function.use_count() == 1) {
      stack.emplace_back(std::move(edge.function));
    } else {
      edge.function.reset();
    }
  }
}

/*
 * Fix for #5534: prevent stack overflow on deletion of deep computation graph
 *
 * Sometimes one can end up with a very big computation graph of Nodes
 * and Edges. Each std::shared_ptr<Node> contains a list of Edge, and
 * each Edge contains a std::shared_ptr<Node>. Deleting a
 * std::shared_ptr<Node> can trigger the recursive deletion of other
 * std::shared_ptr<Node>'s: this can stack overflow if the graph
 * is deep enough. Here is an example of such a graph:
 *
 * shared_ptr<Node> -> Edge -> shared_ptr<Node> -> Edge -> ... ->
 * shared_ptr<Node>
 *
 * The solution here is to detect when we are decrementing away the last
 * reference to a Node, and when doing so to buffer up the Nodes
 * that will be recursively decremented. We can then decrement (and free)
 * the original Node without causing a recursive cascade, before
 * draining the buffer applying the same behavior. This is, in effect,
 * converting recursion to a loop, using a heap buffer in place of the
 * recursive call stack.
 */
void deleteNode(Node* function) {
  // To avoid stack overflow on large computational graphs,
  // we need to track reference decrementing and freeing
  // on the heap.
  function->release_variables();
  std::vector<std::shared_ptr<Node>> stack;
  gatherFunctions(function, stack);
  delete function;

  while (!stack.empty()) {
    auto func = std::move(stack.back());
    stack.pop_back();
    gatherFunctions(func.get(), stack);
    // Reference count is decremented on the loop backedge.
  }
}

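// Illustrative usage (a sketch; the real ownership sites live elsewhere in
// autograd): deleteNode is installed as a custom shared_ptr deleter, so
// dropping the last reference routes destruction through the iterative path
// above rather than the default recursive teardown. `MyNode` is hypothetical:
//
//   std::shared_ptr<Node> node(new MyNode(/* ... */), deleteNode);
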
namespace {
bool kAutogradFunctionExtensionEnabled = false;
}

bool isAutogradFunctionExtensionEnabled() {
  return kAutogradFunctionExtensionEnabled;
}

void setAutogradFunctionExtensionEnabled(bool enabled) {
  kAutogradFunctionExtensionEnabled = enabled;
}

} // namespace autograd
} // namespace torch