Expose API for backward execution order (#87507)

In this PR:
- GraphTask stores the graph roots on construction so that we can later traverse the graph
- before the nodes are returned, they need to be converted from raw pointers to shared_ptr; this is safe because the graph is guaranteed to be alive while the backward pass is running (see the usage sketch below)
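
A minimal usage sketch distilled from the new test further down (the `order` list, the tensor `t`, and the hook name are illustrative, not part of the API surface): register a tensor hook, run the backward pass with multithreading disabled, and read the execution order from inside the hook.

import torch

order = [None]

def hook(_):
    # Only valid while the backward pass for this graph task is running
    order[0] = torch._C._current_graph_task_execution_order()

t = torch.tensor(1., requires_grad=True).clone().sin().exp()
t.register_hook(hook)

# The new API requires the backward pass to run with multithreading disabled
with torch.autograd.set_multithreading_enabled(False):
    t.backward()

print(", ".join(node.name().split(' ')[-1] for node in order[0]))
# ExpBackward0, SinBackward0, CloneBackward0, torch::autograd::AccumulateGrad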

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87507
Approved by: https://github.com/albanD
soulitzer 2022-10-26 13:34:34 -04:00 committed by PyTorch MergeBot
parent 926827b89c
commit adb76ef510
6 changed files with 245 additions and 11 deletions

View File

@@ -3073,6 +3073,128 @@ class TestAutograd(TestCase):
        self.assertEqual(torch._C._current_graph_task_id(), -1)

    def test_current_graph_task_execution_order(self):
        predicted = [None]

        def hook(_):
            predicted[0] = torch._C._current_graph_task_execution_order()

        def names(nodes):
            return ", ".join([node.name().split(' ')[-1] for node in nodes]) + '\n'

        def grad_fns(*tensors):
            # or grad accumulator
            out = []
            for t in tensors:
                if t.requires_grad and t.grad_fn is None:
                    out.append(t.clone().grad_fn.next_functions[0][0])
                else:
                    out.append(t.grad_fn)
            return out

        actual = []

        def register_logging_hooks(*tensors):
            # register hooks that log the order in which they are called
            def get_hook(i):
                def hook(t_):
                    actual.append(tensors[i])
                return hook

            for i, t in enumerate(tensors):
                t.register_hook(get_hook(i))

        # Basic example: single path
        t = torch.tensor(1., requires_grad=True).clone().sin().exp()
        t.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            t.backward()
        self.assertExpectedInline(names(predicted[0]), """\
ExpBackward0, SinBackward0, CloneBackward0, torch::autograd::AccumulateGrad
""")

        # We don't exactly follow sequence_nr order
        a = torch.tensor(1., requires_grad=True)
        b = torch.tensor(2., requires_grad=True)
        c = b.sin()
        d = a.cos()
        out = c * d
        register_logging_hooks(a, b, c, d, out)
        out.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            out.backward()
        self.assertEqual(predicted[0], grad_fns(*actual))
        actual = []

        # Multiple roots are also OK
        a = torch.tensor(1., requires_grad=True)
        b = a * 2
        out = b.sin()
        out2 = b.cos()
        out3 = b.cos()
        register_logging_hooks(a, b, out, out2, out3)
        out3.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            torch.autograd.grad((out, out3, out2), inputs=(a,))
        self.assertExpectedInline(names(predicted[0]), """\
CosBackward0, CosBackward0, SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
""")
        # TODO: Uncomment after update to hooks behavior
        # self.assertEqual(predicted[0], grad_fns(*actual))
        actual = []

        # Case where next node is nullptr
        a = torch.tensor(1., requires_grad=True)
        b = a * 2
        out = b.sin()
        register_logging_hooks(a, b, out)
        out.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            out.backward()
        self.assertEqual(predicted[0], grad_fns(*actual))
        actual = []

        # Case where two `inputs` on the same path
        a = torch.tensor(1., requires_grad=True)
        b = a * 2
        out = b.sin()
        register_logging_hooks(a, b, out)
        out.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            torch.autograd.grad((out,), inputs=(a, b,))
        self.assertEqual(names(predicted[0]), """\
SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
""")
        # TODO: Uncomment after update to hooks behavior
        # self.assertEqual(predicted[0], grad_fns(*actual))
        actual = []

        # Case where `inputs` specifies a subgraph
        a = torch.tensor(1., requires_grad=True)
        b = torch.tensor(1., requires_grad=True)
        c = a * b
        out = c.sin()
        register_logging_hooks(a, b, c, out)
        out.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            torch.autograd.grad((out,), inputs=(a,))
        self.assertEqual(names(predicted[0]), """\
SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
""")
        # TODO: Uncomment after update to hooks behavior
        # self.assertEqual(predicted[0], grad_fns(*actual))
        actual = []

        # Errors when not called in a backward
        with self.assertRaisesRegex(RuntimeError, "should only be called during the backward pass"):
            torch._C._current_graph_task_execution_order()

        # Errors when context manager not enabled
        t = torch.tensor(1., requires_grad=True).clone().sin().exp()
        t.register_hook(hook)
        with self.assertRaisesRegex(RuntimeError, "expects the current backward to be executed with multithreading disabled"):
            t.backward()

    def test_profiler(self):
        x = torch.randn(10, 10)

View File

@@ -813,6 +813,28 @@ PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_getCurrentGraphTaskExecutionOrder(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  std::vector<torch::autograd::Node*> nodes =
      torch::autograd::get_current_graph_task_execution_order();
  TORCH_CHECK(
      nodes.size(),
      "_current_graph_task_execution_order should only be called during the backward pass");
  auto list = THPObjectPtr(PyList_New(nodes.size()));
  if (!list)
    return nullptr;
  for (const auto i : c10::irange(nodes.size())) {
    // This node is guaranteed to be alive since the backward is still running
    PyObject* pyobj_node =
        torch::autograd::functionToPyObject(nodes[i]->getptr());
    PyList_SET_ITEM(list.get(), i, pyobj_node);
  }
  return list.release();
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_getCurrentGraphTaskId(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(torch::autograd::get_current_graph_task_id());
@@ -1019,6 +1041,10 @@ static PyMethodDef TorchMethods[] = {
     THPModule_willEngineExecuteNode,
     METH_O,
     nullptr},
    {"_current_graph_task_execution_order",
     THPModule_getCurrentGraphTaskExecutionOrder,
     METH_NOARGS,
     nullptr},
    {"_current_graph_task_id",
     THPModule_getCurrentGraphTaskId,
     METH_NOARGS,

View File

@@ -398,6 +398,66 @@ void add_node_to_current_graph_task_exec_info(Node* fn) {
  current_graph_task->exec_info_[fn].needed_ = true;
}
// NB: The engine itself does not use the outputs of this function.
std::vector<Node*> get_current_graph_task_execution_order() {
  std::shared_ptr<GraphTask> task = current_graph_task;
  if (!task) {
    return {};
  }

  // We could potentially check if there is only a single device here,
  // but explicitly requiring this context doesn't seem bad either
  TORCH_CHECK(
      !c10::AutogradState::get_tls_state().get_multithreading_enabled(),
      "get_current_graph_task_execution_order expects the current backward to be "
      "executed with multithreading disabled, e.g. by running:\n\n"
      ">>> with torch.autograd.set_multithreading_enabled(False):\n"
      "... torch.autograd.grad(...)\n");

  const bool check_exec_info = !task->exec_info_.empty();
  std::vector<Node*> out{};
  std::unordered_set<Node*> seen{};

  auto compare_seq_nr = [](Node* n1, Node* n2) {
    return n1->sequence_nr() < n2->sequence_nr();
  };
  std::priority_queue<Node*, std::vector<Node*>, decltype(compare_seq_nr)> heap(
      compare_seq_nr);

  for (Node* ptr : task->graph_roots_) {
    heap.push(ptr);
  }

  // Implementation notes:
  // - Don't need to count dependencies because we have sequence_nr
  // - Don't need to check topological_nr because we have exec_info
  while (!heap.empty()) {
    Node* fn = heap.top();
    heap.pop();

    const bool was_inserted = seen.insert(fn).second;
    if (!was_inserted) {
      continue;
    }

    out.push_back(fn);
    for (const auto& edge : fn->next_edges()) {
      Node* next_ptr = edge.function.get();
      if (!next_ptr) {
        continue;
      }
      if (check_exec_info) {
        auto it = task->exec_info_.find(next_ptr);
        if (it == task->exec_info_.end() || !it->second.should_execute()) {
          continue;
        }
      }
      heap.push(next_ptr);
    }
  }
  return out;
}

// NOTE: graph_tasks do not necessarily form a stack. Imagine this
// case:
//
@@ -1050,7 +1110,7 @@ auto Engine::compute_dependencies(
}

auto Engine::execute(
    const edge_list& root_edges,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
@@ -1058,9 +1118,9 @@ auto Engine::execute(
    const edge_list& outputs) -> variable_list {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  validate_outputs(
      root_edges,
      const_cast<variable_list&>(inputs),
      [](const std::string& msg) { return msg; });
  if (accumulate_grad && create_graph) {
    TORCH_WARN_ONCE(
        "Using backward() with create_graph=True will create a reference cycle "
@@ -1083,17 +1143,25 @@ auto Engine::execute(
  init_local_ready_queue();
  bool not_reentrant_backward_call = worker_device == NO_DEVICE;

  // Store root nodes so we can traverse through the graph later
  // e.g., for get_current_graph_task_execution_order
  c10::SmallVector<Node*, 4> temp_roots{root_edges.size()};
  for (const auto i : c10::irange(root_edges.size())) {
    temp_roots[i] = root_edges[i].function.get();
  }

  auto graph_task = std::make_shared<GraphTask>(
      /* keep_graph */ keep_graph,
      /* create_graph */ create_graph,
      /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
      /* cpu_ready_queue */ local_ready_queue,
      /* graph_roots */ std::move(temp_roots));

  // If we receive a single root, skip creating extra root node
  bool skip_dummy_node = root_edges.size() == 1;
  auto graph_root = skip_dummy_node
      ? root_edges.at(0).function
      : std::make_shared<GraphRoot>(root_edges, inputs);

  auto min_topo_nr = compute_min_topological_nr(outputs);
  // Now compute the dependencies for all executable functions
@@ -1106,14 +1174,17 @@ auto Engine::execute(
  // Queue the root
  if (skip_dummy_node) {
    InputBuffer input_buffer(root_edges.at(0).function->num_inputs());
    auto input = inputs.at(0);
    const auto input_stream = InputMetadata(input).stream();
    const auto opt_next_stream =
        root_edges.at(0).function->stream(c10::DeviceType::CUDA);
    input_buffer.add(
        root_edges.at(0).input_nr,
        std::move(input),
        input_stream,
        opt_next_stream);
    execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
  } else {

View File

@@ -143,6 +143,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
  Node& operator=(Node&& other) = delete;
  virtual ~Node() = default;

  std::shared_ptr<Node> getptr() {
    return shared_from_this();
  }

  /// Evaluates the function on the given inputs and returns the result of the
  /// function call.
  variable_list operator()(variable_list&& inputs) {

View File

@@ -37,6 +37,7 @@ struct GraphTask : std::enable_shared_from_this<GraphTask> {
  // Records the nodes that are in the graph
  std::unordered_set<Node*> nodes_in_graph_;
  c10::SmallVector<Node*, 4> graph_roots_;

  // Note [Exec info]
  // Exec info is created for each GraphTask, which allows filtering paths on
  // the graph that are not needed. It has a bit complicated semantics. If it's
@@ -164,8 +165,10 @@ struct GraphTask : std::enable_shared_from_this<GraphTask> {
      bool grad_mode,
      int reentrant_depth,
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      c10::SmallVector<Node*, 4> graph_roots,
      bool exit_on_error = false)
      : keep_graph_(keep_graph),
        graph_roots_(std::move(graph_roots)),
        owner_(NO_DEVICE),
        reentrant_depth_(reentrant_depth),
        exit_on_error_(exit_on_error),
@@ -198,6 +201,7 @@ get_current_graph_task_exec_info();
TORCH_API const std::unordered_set<Node*>*
get_current_graph_task_nodes_in_graph();
TORCH_API bool get_current_graph_task_keep_graph();
TORCH_API std::vector<Node*> get_current_graph_task_execution_order();
TORCH_API int get_current_graph_task_id();
void add_node_to_current_graph_task_exec_info(Node* fn);

View File

@@ -185,6 +185,13 @@ void DistEngine::computeDependencies(
    bool retainGraph) {
  TORCH_INTERNAL_ASSERT(graphRoot, "graphRoot is null!");

  // Store root nodes so we can traverse through the graph later
  // e.g., for get_current_graph_task_execution_order
  c10::SmallVector<Node*, 4> temp_roots{rootEdges.size()};
  for (const auto i : c10::irange(rootEdges.size())) {
    temp_roots[i] = rootEdges[i].function.get();
  }

  // Build the graph task and graph root.
  // NOTE: we don't need to build and pass a cpu_ready_queue to GraphTask
  // as we use execute_graph_task_until_ready_queue_empty, which will build
@@ -194,6 +201,7 @@ void DistEngine::computeDependencies(
      /* create_graph */ false,
      /* depth */ 0,
      /* cpu_ready_queue */ global_cpu_ready_queue_,
      /* graph_roots */ temp_roots,
      /* exit_on_error */ true);

  // Run BFS to traverse the graph locally. The roots of the graph are