Expose API for backward execution order (#87507)
In this PR:
- graph_task stores its graph roots on construction so that we can later traverse through the graph
- before the nodes are returned, they need to be converted from raw_ptr to shared_ptr; this is safe because the graph is guaranteed to be alive while the backward pass is running

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87507
Approved by: https://github.com/albanD
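For orientation, a minimal usage sketch of the new private API (not part of the diff below; it simply mirrors the test added in this commit): torch._C._current_graph_task_execution_order() may only be called while a backward pass is running, e.g. from a tensor hook, and that backward must run with multithreading disabled.

    import torch

    order = [None]

    def hook(_):
        # Valid only while the backward pass for this graph task is running.
        order[0] = torch._C._current_graph_task_execution_order()

    t = torch.tensor(1., requires_grad=True).sin().exp()
    t.register_hook(hook)

    # The API requires the current backward to execute with multithreading disabled.
    with torch.autograd.set_multithreading_enabled(False):
        t.backward()

    print([node.name() for node in order[0]])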
This commit is contained in:
parent 926827b89c
commit adb76ef510
@@ -3073,6 +3073,128 @@ class TestAutograd(TestCase):
         self.assertEqual(torch._C._current_graph_task_id(), -1)
 
+    def test_current_graph_task_execution_order(self):
+        predicted = [None]
+
+        def hook(_):
+            predicted[0] = torch._C._current_graph_task_execution_order()
+
+        def names(nodes):
+            return ", ".join([node.name().split(' ')[-1] for node in nodes]) + '\n'
+
+        def grad_fns(*tensors):
+            # or grad accumulator
+            out = []
+            for t in tensors:
+                if t.requires_grad and t.grad_fn is None:
+                    out.append(t.clone().grad_fn.next_functions[0][0])
+                else:
+                    out.append(t.grad_fn)
+            return out
+
+        actual = []
+
+        def register_logging_hooks(*tensors):
+            # register hooks that log the order in which they are called
+            def get_hook(i):
+                def hook(t_):
+                    actual.append(tensors[i])
+                return hook
+
+            for i, t in enumerate(tensors):
+                t.register_hook(get_hook(i))
+
+        # Basic example: single path
+        t = torch.tensor(1., requires_grad=True).clone().sin().exp()
+        t.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            t.backward()
+        self.assertExpectedInline(names(predicted[0]), """\
+ExpBackward0, SinBackward0, CloneBackward0, torch::autograd::AccumulateGrad
+""")
+
+        # We don't exactly follow sequence_nr order
+        a = torch.tensor(1., requires_grad=True)
+        b = torch.tensor(2., requires_grad=True)
+        c = b.sin()
+        d = a.cos()
+        out = c * d
+        register_logging_hooks(a, b, c, d, out)
+        out.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            out.backward()
+        self.assertEqual(predicted[0], grad_fns(*actual))
+        actual = []
+
+        # Multiple roots are also OK
+        a = torch.tensor(1., requires_grad=True)
+        b = a * 2
+        out = b.sin()
+        out2 = b.cos()
+        out3 = b.cos()
+        register_logging_hooks(a, b, out, out2, out3)
+        out3.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            torch.autograd.grad((out, out3, out2), inputs=(a,))
+        self.assertExpectedInline(names(predicted[0]), """\
+CosBackward0, CosBackward0, SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
+""")
+        # TODO: Uncomment after update to hooks behavior
+        # self.assertEqual(predicted[0], grad_fns(*actual))
+        actual = []
+
+        # Case where next node is nullptr
+        a = torch.tensor(1., requires_grad=True)
+        b = a * 2
+        out = b.sin()
+        register_logging_hooks(a, b, out)
+        out.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            out.backward()
+        self.assertEqual(predicted[0], grad_fns(*actual))
+        actual = []
+
+        # Case where two `inputs` on the same path
+        a = torch.tensor(1., requires_grad=True)
+        b = a * 2
+        out = b.sin()
+        register_logging_hooks(a, b, out)
+        out.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            torch.autograd.grad((out,), inputs=(a, b,))
+        self.assertEqual(names(predicted[0]), """\
+SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
+""")
+        # TODO: Uncomment after update to hooks behavior
+        # self.assertEqual(predicted[0], grad_fns(*actual))
+        actual = []
+
+        # Case where `inputs` specifies a subgraph
+        a = torch.tensor(1., requires_grad=True)
+        b = torch.tensor(1., requires_grad=True)
+        c = a * b
+        out = c.sin()
+        register_logging_hooks(a, b, c, out)
+        out.register_hook(hook)
+        with torch.autograd.set_multithreading_enabled(False):
+            torch.autograd.grad((out,), inputs=(a,))
+        self.assertEqual(names(predicted[0]), """\
+SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
+""")
+        # TODO: Uncomment after update to hooks behavior
+        # self.assertEqual(predicted[0], grad_fns(*actual))
+        actual = []
+
+        # Errors when not called in a backward
+        with self.assertRaisesRegex(RuntimeError, "should only be called during the backward pass"):
+            torch._C._current_graph_task_execution_order()
+
+        # Errors when context manager not enabled
+        t = torch.tensor(1., requires_grad=True).clone().sin().exp()
+        t.register_hook(hook)
+        with self.assertRaisesRegex(RuntimeError, "expects the current backward to be executed with multithreading disabled"):
+            t.backward()
+
     def test_profiler(self):
         x = torch.randn(10, 10)
@@ -813,6 +813,28 @@ PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }
 
+PyObject* THPModule_getCurrentGraphTaskExecutionOrder(
+    PyObject* _unused,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  std::vector<torch::autograd::Node*> nodes =
+      torch::autograd::get_current_graph_task_execution_order();
+  TORCH_CHECK(
+      nodes.size(),
+      "_current_graph_task_execution_order should only be called during the backward pass");
+  auto list = THPObjectPtr(PyList_New(nodes.size()));
+  if (!list)
+    return nullptr;
+  for (const auto i : c10::irange(nodes.size())) {
+    // This node is guaranteed to be alive since the backward is still running
+    PyObject* pyobj_node =
+        torch::autograd::functionToPyObject(nodes[i]->getptr());
+    PyList_SET_ITEM(list.get(), i, pyobj_node);
+  }
+  return list.release();
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject* THPModule_getCurrentGraphTaskId(PyObject* _unused, PyObject* noargs) {
   HANDLE_TH_ERRORS
   return THPUtils_packInt64(torch::autograd::get_current_graph_task_id());
@@ -1019,6 +1041,10 @@ static PyMethodDef TorchMethods[] = {
      THPModule_willEngineExecuteNode,
      METH_O,
      nullptr},
+    {"_current_graph_task_execution_order",
+     THPModule_getCurrentGraphTaskExecutionOrder,
+     METH_NOARGS,
+     nullptr},
     {"_current_graph_task_id",
      THPModule_getCurrentGraphTaskId,
      METH_NOARGS,
@@ -398,6 +398,66 @@ void add_node_to_current_graph_task_exec_info(Node* fn) {
   current_graph_task->exec_info_[fn].needed_ = true;
 }
 
+// NB: The engine itself does not use the outputs of this function.
+std::vector<Node*> get_current_graph_task_execution_order() {
+  std::shared_ptr<GraphTask> task = current_graph_task;
+  if (!task) {
+    return {};
+  }
+
+  // We could potentially check if there is only a single device here,
+  // but explicitly requiring this context doesn't seem bad either
+  TORCH_CHECK(
+      !c10::AutogradState::get_tls_state().get_multithreading_enabled(),
+      "get_current_graph_task_execution_order expects the current backward to be "
+      "executed with multithreading disabled, e.g. by running:\n\n"
+      ">>> with torch.autograd.set_multithreading_enabled(False):\n"
+      "...     torch.autograd.grad(...)\n");
+
+  const bool check_exec_info = !task->exec_info_.empty();
+  std::vector<Node*> out{};
+  std::unordered_set<Node*> seen{};
+
+  auto compare_seq_nr = [](Node* n1, Node* n2) {
+    return n1->sequence_nr() < n2->sequence_nr();
+  };
+  std::priority_queue<Node*, std::vector<Node*>, decltype(compare_seq_nr)> heap(
+      compare_seq_nr);
+
+  for (Node* ptr : task->graph_roots_) {
+    heap.push(ptr);
+  }
+
+  // Implementation notes:
+  // - Don't need to count dependencies because we have sequence_nr
+  // - Don't need to check topological_nr because we have exec_info
+  while (!heap.empty()) {
+    Node* fn = heap.top();
+    heap.pop();
+
+    const bool was_inserted = seen.insert(fn).second;
+    if (!was_inserted) {
+      continue;
+    }
+
+    out.push_back(fn);
+    for (const auto& edge : fn->next_edges()) {
+      Node* next_ptr = edge.function.get();
+      if (!next_ptr) {
+        continue;
+      }
+      if (check_exec_info) {
+        auto it = task->exec_info_.find(next_ptr);
+        if (it == task->exec_info_.end() || !it->second.should_execute()) {
+          continue;
+        }
+      }
+      heap.push(next_ptr);
+    }
+  }
+  return out;
+}
+
 // NOTE: graph_tasks do not necessarily form a stack. Imagine this
 // case:
 //
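As a reading aid (not part of the commit), the traversal in get_current_graph_task_execution_order above can be restated as a short Python sketch: a best-first walk that pops the node with the largest sequence_nr first, de-duplicates with a seen set, and follows next_edges. The exec_info filtering is omitted, and the names next_edges/seq_nr are illustrative stand-ins for the C++ fields.

    import heapq

    def execution_order(roots, next_edges, seq_nr):
        # roots: node ids of the graph roots
        # next_edges: dict node id -> list of successor node ids
        # seq_nr: dict node id -> int, creation order from the forward pass
        heap = [(-seq_nr[n], n) for n in roots]  # max-heap on seq_nr via negation
        heapq.heapify(heap)
        seen, out = set(), []
        while heap:
            _, node = heapq.heappop(heap)
            if node in seen:
                continue
            seen.add(node)
            out.append(node)
            for nxt in next_edges.get(node, []):
                heapq.heappush(heap, (-seq_nr[nxt], nxt))
        return out

Because sequence_nr increases with creation order during the forward pass, nodes created later are visited earlier, which roughly approximates the order in which the engine executes them.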
@@ -1050,7 +1110,7 @@ auto Engine::compute_dependencies(
 }
 
 auto Engine::execute(
-    const edge_list& roots,
+    const edge_list& root_edges,
     const variable_list& inputs,
     bool keep_graph,
     bool create_graph,
@@ -1058,9 +1118,9 @@ auto Engine::execute(
     const edge_list& outputs) -> variable_list {
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   validate_outputs(
-      roots, const_cast<variable_list&>(inputs), [](const std::string& msg) {
-        return msg;
-      });
+      root_edges,
+      const_cast<variable_list&>(inputs),
+      [](const std::string& msg) { return msg; });
   if (accumulate_grad && create_graph) {
     TORCH_WARN_ONCE(
         "Using backward() with create_graph=True will create a reference cycle "
@@ -1083,17 +1143,25 @@ auto Engine::execute(
   init_local_ready_queue();
   bool not_reentrant_backward_call = worker_device == NO_DEVICE;
 
+  // Store root nodes so we can traverse through the graph later
+  // e.g., for get_current_graph_task_execution_order
+  c10::SmallVector<Node*, 4> temp_roots{root_edges.size()};
+  for (const auto i : c10::irange(root_edges.size())) {
+    temp_roots[i] = root_edges[i].function.get();
+  }
+
   auto graph_task = std::make_shared<GraphTask>(
       /* keep_graph */ keep_graph,
       /* create_graph */ create_graph,
       /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
-      /* cpu_ready_queue */ local_ready_queue);
+      /* cpu_ready_queue */ local_ready_queue,
+      /* graph_roots */ std::move(temp_roots));
 
   // If we receive a single root, skip creating extra root node
-  bool skip_dummy_node = roots.size() == 1;
+  bool skip_dummy_node = root_edges.size() == 1;
   auto graph_root = skip_dummy_node
-      ? roots.at(0).function
-      : std::make_shared<GraphRoot>(roots, inputs);
+      ? root_edges.at(0).function
+      : std::make_shared<GraphRoot>(root_edges, inputs);
 
   auto min_topo_nr = compute_min_topological_nr(outputs);
   // Now compute the dependencies for all executable functions
@@ -1106,14 +1174,17 @@ auto Engine::execute(
   // Queue the root
   if (skip_dummy_node) {
-    InputBuffer input_buffer(roots.at(0).function->num_inputs());
+    InputBuffer input_buffer(root_edges.at(0).function->num_inputs());
     auto input = inputs.at(0);
 
     const auto input_stream = InputMetadata(input).stream();
     const auto opt_next_stream =
-        roots.at(0).function->stream(c10::DeviceType::CUDA);
+        root_edges.at(0).function->stream(c10::DeviceType::CUDA);
     input_buffer.add(
-        roots.at(0).input_nr, std::move(input), input_stream, opt_next_stream);
+        root_edges.at(0).input_nr,
+        std::move(input),
+        input_stream,
+        opt_next_stream);
 
     execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
   } else {
@@ -143,6 +143,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   Node& operator=(Node&& other) = delete;
   virtual ~Node() = default;
 
+  std::shared_ptr<Node> getptr() {
+    return shared_from_this();
+  }
   /// Evaluates the function on the given inputs and returns the result of the
   /// function call.
   variable_list operator()(variable_list&& inputs) {
@@ -37,6 +37,7 @@ struct GraphTask : std::enable_shared_from_this<GraphTask> {
 
   // Records the nodes that are in the graph
   std::unordered_set<Node*> nodes_in_graph_;
+  c10::SmallVector<Node*, 4> graph_roots_;
   // Note [Exec info]
   // Exec info is created for each GraphTask, which allows filtering paths on
   // the graph that are not needed. It has a bit complicated semantics. If it's
@@ -164,8 +165,10 @@ struct GraphTask : std::enable_shared_from_this<GraphTask> {
       bool grad_mode,
       int reentrant_depth,
       std::shared_ptr<ReadyQueue> cpu_ready_queue,
+      c10::SmallVector<Node*, 4> graph_roots,
       bool exit_on_error = false)
       : keep_graph_(keep_graph),
+        graph_roots_(std::move(graph_roots)),
         owner_(NO_DEVICE),
         reentrant_depth_(reentrant_depth),
         exit_on_error_(exit_on_error),
@@ -198,6 +201,7 @@ get_current_graph_task_exec_info();
 TORCH_API const std::unordered_set<Node*>*
 get_current_graph_task_nodes_in_graph();
 TORCH_API bool get_current_graph_task_keep_graph();
+TORCH_API std::vector<Node*> get_current_graph_task_execution_order();
 TORCH_API int get_current_graph_task_id();
 void add_node_to_current_graph_task_exec_info(Node* fn);
@@ -185,6 +185,13 @@ void DistEngine::computeDependencies(
     bool retainGraph) {
   TORCH_INTERNAL_ASSERT(graphRoot, "graphRoot is null!");
 
+  // Store root nodes so we can traverse through the graph later
+  // e.g., for get_current_graph_task_execution_order
+  c10::SmallVector<Node*, 4> temp_roots{rootEdges.size()};
+  for (const auto i : c10::irange(rootEdges.size())) {
+    temp_roots[i] = rootEdges[i].function.get();
+  }
+
   // Build the graph task and graph root.
   // NOTE: we don't need to build and pass a cpu_ready_queue to GraphTask
   // as we use execute_graph_task_until_ready_queue_empty, which will build
@@ -194,6 +201,7 @@ void DistEngine::computeDependencies(
       /* create_graph */ false,
       /* depth */ 0,
       /* cpu_ready_queue */ global_cpu_ready_queue_,
+      /* graph_roots */ temp_roots,
       /* exit_on_error */ true);
 
   // Run BFS to traverse the graph locally. The roots of the graph are