Expose API for backward execution order (#87507)
In this PR:
- graph_task stores the graph roots on construction so that we can later traverse through the graph
- before the nodes are returned, they need to be converted from raw_ptr to shared_ptr; this is OK because the graph is guaranteed to be alive while the backward is still running

A short usage sketch of the new API follows below.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87507
Approved by: https://github.com/albanD
This commit is contained in:
parent 926827b89c
commit adb76ef510
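Minimal usage sketch of the API exposed here, assembled from the test added in this diff (it is illustrative, not part of the commit): a tensor hook calls torch._C._current_graph_task_execution_order() while the backward pass runs with autograd multithreading disabled.

# Usage sketch; mirrors test_current_graph_task_execution_order below.
# The API may only be called during a backward pass, and that backward must
# run with autograd multithreading disabled.
import torch

order = [None]

def hook(_):
    # Captures the graph task's nodes in the order the engine would execute them.
    order[0] = torch._C._current_graph_task_execution_order()

t = torch.tensor(1., requires_grad=True).clone().sin().exp()
t.register_hook(hook)
with torch.autograd.set_multithreading_enabled(False):
    t.backward()

print(", ".join(node.name().split(' ')[-1] for node in order[0]))
# ExpBackward0, SinBackward0, CloneBackward0, torch::autograd::AccumulateGrad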
@@ -3073,6 +3073,128 @@ class TestAutograd(TestCase):
         self.assertEqual(torch._C._current_graph_task_id(), -1)
 
+    def test_current_graph_task_execution_order(self):
+        predicted = [None]
+
+        def hook(_):
+            predicted[0] = torch._C._current_graph_task_execution_order()
+
+        def names(nodes):
+            return ", ".join([node.name().split(' ')[-1] for node in nodes]) + '\n'
+
+        def grad_fns(*tensors):
+            # or grad accumulator
+            out = []
+            for t in tensors:
+                if t.requires_grad and t.grad_fn is None:
+                    out.append(t.clone().grad_fn.next_functions[0][0])
+                else:
+                    out.append(t.grad_fn)
+            return out
+
+        actual = []
+
+        def register_logging_hooks(*tensors):
+            # register hooks that log the order in which they are called
+            def get_hook(i):
+                def hook(t_):
+                    actual.append(tensors[i])
+                return hook
+
+            for i, t in enumerate(tensors):
+                t.register_hook(get_hook(i))
+
+        # Basic example: single path
+        t = torch.tensor(1., requires_grad=True).clone().sin().exp()
+        t.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            t.backward()
+        self.assertExpectedInline(names(predicted[0]), """\
+ExpBackward0, SinBackward0, CloneBackward0, torch::autograd::AccumulateGrad
+""")
+
+        # We don't exactly follow sequence_nr order
+        a = torch.tensor(1., requires_grad=True)
+        b = torch.tensor(2., requires_grad=True)
+        c = b.sin()
+        d = a.cos()
+        out = c * d
+        register_logging_hooks(a, b, c, d, out)
+        out.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            out.backward()
+        self.assertEqual(predicted[0], grad_fns(*actual))
+        actual = []
+
+        # Multiple roots are also OK
+        a = torch.tensor(1., requires_grad=True)
+        b = a * 2
+        out = b.sin()
+        out2 = b.cos()
+        out3 = b.cos()
+        register_logging_hooks(a, b, out, out2, out3)
+        out3.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            torch.autograd.grad((out, out3, out2), inputs=(a,))
+        self.assertExpectedInline(names(predicted[0]), """\
+CosBackward0, CosBackward0, SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
+""")
+        # TODO: Uncomment after update to hooks behavior
+        # self.assertEqual(predicted[0], grad_fns(*actual))
+        actual = []
+
+        # Case where next node is nullptr
+        a = torch.tensor(1., requires_grad=True)
+        b = a * 2
+        out = b.sin()
+        register_logging_hooks(a, b, out)
+        out.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            out.backward()
+        self.assertEqual(predicted[0], grad_fns(*actual))
+        actual = []
+
+        # Case where two `inputs` on the same path
+        a = torch.tensor(1., requires_grad=True)
+        b = a * 2
+        out = b.sin()
+        register_logging_hooks(a, b, out)
+        out.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            torch.autograd.grad((out,), inputs=(a, b,))
+        self.assertEqual(names(predicted[0]), """\
+SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
+""")
+        # TODO: Uncomment after update to hooks behavior
+        # self.assertEqual(predicted[0], grad_fns(*actual))
+        actual = []
+
+        # Case where `inputs` specifies a subgraph
+        a = torch.tensor(1., requires_grad=True)
+        b = torch.tensor(1., requires_grad=True)
+        c = a * b
+        out = c.sin()
+        register_logging_hooks(a, b, c, out)
+        out.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            torch.autograd.grad((out,), inputs=(a,))
+        self.assertEqual(names(predicted[0]), """\
+SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
+""")
+        # TODO: Uncomment after update to hooks behavior
+        # self.assertEqual(predicted[0], grad_fns(*actual))
+        actual = []
+
+        # Errors when not called in a backward
+        with self.assertRaisesRegex(RuntimeError, "should only be called during the backward pass"):
+            torch._C._current_graph_task_execution_order()
+
+        # Errors when context manager not enabled
+        t = torch.tensor(1., requires_grad=True).clone().sin().exp()
+        t.register_hook(hook)
+        with self.assertRaisesRegex(RuntimeError, "expects the current backward to be executed with multithreading disabled"):
+            t.backward()
+
     def test_profiler(self):
         x = torch.randn(10, 10)
@@ -813,6 +813,28 @@ PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }
 
+PyObject* THPModule_getCurrentGraphTaskExecutionOrder(
+    PyObject* _unused,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  std::vector<torch::autograd::Node*> nodes =
+      torch::autograd::get_current_graph_task_execution_order();
+  TORCH_CHECK(
+      nodes.size(),
+      "_current_graph_task_execution_order should only be called during the backward pass");
+  auto list = THPObjectPtr(PyList_New(nodes.size()));
+  if (!list)
+    return nullptr;
+  for (const auto i : c10::irange(nodes.size())) {
+    // This node is guaranteed to be alive since the backward is still running
+    PyObject* pyobj_node =
+        torch::autograd::functionToPyObject(nodes[i]->getptr());
+    PyList_SET_ITEM(list.get(), i, pyobj_node);
+  }
+  return list.release();
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject* THPModule_getCurrentGraphTaskId(PyObject* _unused, PyObject* noargs) {
   HANDLE_TH_ERRORS
   return THPUtils_packInt64(torch::autograd::get_current_graph_task_id());
@@ -1019,6 +1041,10 @@ static PyMethodDef TorchMethods[] = {
      THPModule_willEngineExecuteNode,
      METH_O,
      nullptr},
+    {"_current_graph_task_execution_order",
+     THPModule_getCurrentGraphTaskExecutionOrder,
+     METH_NOARGS,
+     nullptr},
     {"_current_graph_task_id",
      THPModule_getCurrentGraphTaskId,
      METH_NOARGS,
@@ -398,6 +398,66 @@ void add_node_to_current_graph_task_exec_info(Node* fn) {
   current_graph_task->exec_info_[fn].needed_ = true;
 }
 
+// NB: The engine itself does not use the outputs of this function.
+std::vector<Node*> get_current_graph_task_execution_order() {
+  std::shared_ptr<GraphTask> task = current_graph_task;
+  if (!task) {
+    return {};
+  }
+
+  // We could potentially check if there is only a single device here,
+  // but explicitly requiring this context doesn't seem bad either
+  TORCH_CHECK(
+      !c10::AutogradState::get_tls_state().get_multithreading_enabled(),
+      "get_current_graph_task_execution_order expects the current backward to be "
+      "executed with multithreading disabled, e.g. by running:\n\n"
+      ">>> with torch.autograd.set_multithreading_enabled(False):\n"
+      "...     torch.autograd.grad(...)\n");
+
+  const bool check_exec_info = !task->exec_info_.empty();
+  std::vector<Node*> out{};
+  std::unordered_set<Node*> seen{};
+
+  auto compare_seq_nr = [](Node* n1, Node* n2) {
+    return n1->sequence_nr() < n2->sequence_nr();
+  };
+  std::priority_queue<Node*, std::vector<Node*>, decltype(compare_seq_nr)> heap(
+      compare_seq_nr);
+
+  for (Node* ptr : task->graph_roots_) {
+    heap.push(ptr);
+  }
+
+  // Implementation notes:
+  // - Don't need to count dependencies because we have sequence_nr
+  // - Don't need to check topological_nr because we have exec_info
+  while (!heap.empty()) {
+    Node* fn = heap.top();
+    heap.pop();
+
+    const bool was_inserted = seen.insert(fn).second;
+    if (!was_inserted) {
+      continue;
+    }
+
+    out.push_back(fn);
+    for (const auto& edge : fn->next_edges()) {
+      Node* next_ptr = edge.function.get();
+      if (!next_ptr) {
+        continue;
+      }
+      if (check_exec_info) {
+        auto it = task->exec_info_.find(next_ptr);
+        if (it == task->exec_info_.end() || !it->second.should_execute()) {
+          continue;
+        }
+      }
+      heap.push(next_ptr);
+    }
+  }
+  return out;
+}
+
 // NOTE: graph_tasks do not necessarily form a stack. Imagine this
 // case:
 //
@@ -1050,7 +1110,7 @@ auto Engine::compute_dependencies(
 }
 
 auto Engine::execute(
-    const edge_list& roots,
+    const edge_list& root_edges,
     const variable_list& inputs,
     bool keep_graph,
     bool create_graph,
@@ -1058,9 +1118,9 @@ auto Engine::execute(
     const edge_list& outputs) -> variable_list {
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   validate_outputs(
-      roots, const_cast<variable_list&>(inputs), [](const std::string& msg) {
-        return msg;
-      });
+      root_edges,
+      const_cast<variable_list&>(inputs),
+      [](const std::string& msg) { return msg; });
   if (accumulate_grad && create_graph) {
     TORCH_WARN_ONCE(
         "Using backward() with create_graph=True will create a reference cycle "
@@ -1083,17 +1143,25 @@ auto Engine::execute(
   init_local_ready_queue();
   bool not_reentrant_backward_call = worker_device == NO_DEVICE;
 
+  // Store root nodes so we can traverse through the graph later
+  // e.g., for get_current_graph_task_execution_order
+  c10::SmallVector<Node*, 4> temp_roots{root_edges.size()};
+  for (const auto i : c10::irange(root_edges.size())) {
+    temp_roots[i] = root_edges[i].function.get();
+  }
+
   auto graph_task = std::make_shared<GraphTask>(
       /* keep_graph */ keep_graph,
       /* create_graph */ create_graph,
       /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
-      /* cpu_ready_queue */ local_ready_queue);
+      /* cpu_ready_queue */ local_ready_queue,
+      /* graph_roots */ std::move(temp_roots));
 
   // If we receive a single root, skip creating extra root node
-  bool skip_dummy_node = roots.size() == 1;
+  bool skip_dummy_node = root_edges.size() == 1;
   auto graph_root = skip_dummy_node
-      ? roots.at(0).function
-      : std::make_shared<GraphRoot>(roots, inputs);
+      ? root_edges.at(0).function
+      : std::make_shared<GraphRoot>(root_edges, inputs);
 
   auto min_topo_nr = compute_min_topological_nr(outputs);
   // Now compute the dependencies for all executable functions
@@ -1106,14 +1174,17 @@ auto Engine::execute(
 
   // Queue the root
   if (skip_dummy_node) {
-    InputBuffer input_buffer(roots.at(0).function->num_inputs());
+    InputBuffer input_buffer(root_edges.at(0).function->num_inputs());
     auto input = inputs.at(0);
 
     const auto input_stream = InputMetadata(input).stream();
     const auto opt_next_stream =
-        roots.at(0).function->stream(c10::DeviceType::CUDA);
+        root_edges.at(0).function->stream(c10::DeviceType::CUDA);
     input_buffer.add(
-        roots.at(0).input_nr, std::move(input), input_stream, opt_next_stream);
+        root_edges.at(0).input_nr,
+        std::move(input),
+        input_stream,
+        opt_next_stream);
 
     execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
   } else {
@@ -143,6 +143,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   Node& operator=(Node&& other) = delete;
   virtual ~Node() = default;
 
+  std::shared_ptr<Node> getptr() {
+    return shared_from_this();
+  }
   /// Evaluates the function on the given inputs and returns the result of the
   /// function call.
   variable_list operator()(variable_list&& inputs) {
@@ -37,6 +37,7 @@ struct GraphTask : std::enable_shared_from_this<GraphTask> {
 
   // Records the nodes that are in the graph
   std::unordered_set<Node*> nodes_in_graph_;
+  c10::SmallVector<Node*, 4> graph_roots_;
   // Note [Exec info]
   // Exec info is created for each GraphTask, which allows filtering paths on
   // the graph that are not needed. It has a bit complicated semantics. If it's
@@ -164,8 +165,10 @@ struct GraphTask : std::enable_shared_from_this<GraphTask> {
       bool grad_mode,
       int reentrant_depth,
       std::shared_ptr<ReadyQueue> cpu_ready_queue,
+      c10::SmallVector<Node*, 4> graph_roots,
       bool exit_on_error = false)
       : keep_graph_(keep_graph),
+        graph_roots_(std::move(graph_roots)),
         owner_(NO_DEVICE),
         reentrant_depth_(reentrant_depth),
         exit_on_error_(exit_on_error),
@@ -198,6 +201,7 @@ get_current_graph_task_exec_info();
 TORCH_API const std::unordered_set<Node*>*
 get_current_graph_task_nodes_in_graph();
 TORCH_API bool get_current_graph_task_keep_graph();
+TORCH_API std::vector<Node*> get_current_graph_task_execution_order();
 TORCH_API int get_current_graph_task_id();
 void add_node_to_current_graph_task_exec_info(Node* fn);
 
@@ -185,6 +185,13 @@ void DistEngine::computeDependencies(
     bool retainGraph) {
   TORCH_INTERNAL_ASSERT(graphRoot, "graphRoot is null!");
 
+  // Store root nodes so we can traverse through the graph later
+  // e.g., for get_current_graph_task_execution_order
+  c10::SmallVector<Node*, 4> temp_roots{rootEdges.size()};
+  for (const auto i : c10::irange(rootEdges.size())) {
+    temp_roots[i] = rootEdges[i].function.get();
+  }
+
   // Build the graph task and graph root.
   // NOTE: we don't need to build and pass a cpu_ready_queue to GraphTask
   // as we use execute_graph_task_until_ready_queue_empty, which will build
@@ -194,6 +201,7 @@ void DistEngine::computeDependencies(
       /* create_graph */ false,
       /* depth */ 0,
       /* cpu_ready_queue */ global_cpu_ready_queue_,
+      /* graph_roots */ temp_roots,
       /* exit_on_error */ true);
 
   // Run BFS to traverse the graph locally. The roots of the graph are