pytorch/torch/distributed/pipelining/_backward.py
Howard Huang 141cae2eb8 [pipelining] Fix more leaks and check leaks in tests (#136584)
Fix two more leaks of the same variety as #136507 (see that PR desc and attached gdoc for debug details).

This time, also add a test-time check that helped to discover new leaks and ensures we won't accidentally regress.

Adds `check_tensor_leak` util which internally asserts no tensors are being kept alive by other objects involved in py ref cycles.

Uses objgraph for a nice debug utility when a leak is found.

Credit to @H-Huang for pointing out objdump and helping debug the `param_group["intermediates"]` leak.

I manually confirmed that all 3 of the leaks identified/fixed so far are caught by the unit test and checker.

Sample output, if I re-introduce a leak by commenting out `del param_group["intermediates"]` in _backward.py,
and run `python test/distributed/pipelining/test_schedule_multiproc.py -k test_schedule_with_native_zero_bubble`:

```
warnings.warn(
/data/users/whc/pytorch/torch/testing/_internal/common_utils.py:5341: UserWarning: 34 tensors were found in the garbage. Did you introduce a reference cycle?
warnings.warn(
/data/users/whc/pytorch/torch/testing/_internal/common_utils.py:5347: UserWarning: Dumping first 1 objgraphs of leaked tensors rendered to png
Graph written to /tmp/objgraph-ztz642h3.dot (19 nodes)
Graph viewer (xdot) not found, generating a png instead
Image generated as /tmp/objgraph-ztz642h3.png
```

rendering of ` /tmp/objgraph-ztz642h3.png`:
<img width="1671" alt="image" src="https://github.com/user-attachments/assets/9098ff29-224c-4533-935b-83c210ac2e22">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136584
Approved by: https://github.com/kwen2501, https://github.com/H-Huang
ghstack dependencies: #136507

Co-authored-by: Howard Huang <howardhuang@fb.com>
2024-09-26 01:10:40 +00:00

399 lines
15 KiB
Python

# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import collections
import logging
import weakref
from typing import Any, cast, Deque, Dict, Iterator, List, Optional, Set, Tuple, Union
import torch
from torch.autograd.graph import GradientEdge, Node
from torch.nn import Parameter
from ._debug import map_debug_info
logger = logging.getLogger(__name__)
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]:
    """
    Get the grad function or grad accumulator for a tensor.

    If ``t`` has a ``grad_fn``, or does not require grad, ``t.grad_fn`` is
    returned unchanged (it may be ``None``).

    If ``t`` requires grad but has no ``grad_fn``, a dummy view of ``t`` is
    created and the first entry of the view's ``grad_fn.next_functions`` is
    returned. Accumulate grad nodes are lazily created, so the dummy view is
    needed in order to trigger the node's creation.

    Raises:
        RuntimeError: if the dummy view itself has no ``grad_fn``.
    """
    if not (t.requires_grad and t.grad_fn is None):
        return t.grad_fn

    # No grad function on the tensor itself: go through a view of it.
    viewed_t = t.view_as(t)
    grad_fn = viewed_t.grad_fn
    if grad_fn is None:
        raise RuntimeError(
            "Attempted to get grad_fn, but got None. "
            "Is this being created in a no-grad context?"
        )
    return grad_fn.next_functions[0][0]
def reverse_closure(
    roots: List[Node], target_nodes: Set[Node]
) -> Tuple[Set[Node], Set[Node]]:
    """
    Breadth-first walk over the ``"reverse_edges"`` stored in node metadata.

    Starting from ``roots`` (``None`` entries and duplicates are skipped),
    every node reachable through reverse edges is collected. A reachable node
    that belongs to ``target_nodes`` is recorded separately and is not
    expanded further.

    Returns:
        A pair ``(closure, visited_target_nodes)``: the set of traversed
        nodes (roots included, targets excluded unless they are roots) and
        the set of target nodes that were encountered.
    """
    reached: Set[Node] = set()
    hit_targets = set()
    frontier: Deque[Node] = collections.deque()

    for root in roots:
        if root is None or root in reached:
            continue
        reached.add(root)
        frontier.append(root)

    while frontier:
        current = frontier.popleft()
        node_meta = cast(Dict[str, List], current.metadata)
        for holder_ref, _ in node_meta.get("reverse_edges", []):
            holder = holder_ref()
            if holder is None:
                # The object behind this reverse edge is gone; skip the edge
                # instead of raising.
                continue
            parent = holder.node
            if parent is None or parent in reached:
                continue
            if parent in target_nodes:
                # Stop at targets: remember them but do not walk past them.
                hit_targets.add(parent)
            else:
                reached.add(parent)
                frontier.append(parent)

    return reached, hit_targets
# Wrapper that enables weak pointers to a node
class Holder:
    """
    Thin wrapper holding a single autograd ``Node`` in ``self.node``.

    ``construct_reverse_graph`` keeps the ``Holder`` instances in the list it
    returns and stores only ``weakref.ref`` objects pointing at them inside
    node metadata.
    """

    def __init__(self, node: Node):
        self.node = node
def construct_reverse_graph(roots: List[Node]) -> List[Holder]:
    """
    Annotate the autograd graph below ``roots`` with reverse edges.

    The graph is walked breadth-first through ``next_functions``. For every
    edge ``parent -> child`` a ``(weakref to Holder(parent), idx)`` tuple is
    appended to ``child.metadata["reverse_edges"]``.

    Returns:
        The list of ``Holder`` objects created during the walk. The metadata
        only holds weak references to them, so the reverse edges resolve only
        while these holders are alive.
    """
    pending: Deque[Node] = collections.deque()
    seen_roots: Set[Node] = set()
    holders: List[Holder] = []

    for root in roots:
        if root is None or root in seen_roots:
            continue
        pending.append(root)
        seen_roots.add(root)

    while pending:
        parent = pending.popleft()
        for child, input_idx in parent.next_functions:
            if child is None:
                continue
            # Don't necessarily need to store on the graph
            child_meta = cast(Dict[str, List], child.metadata)
            edges = child_meta.get("reverse_edges", [])
            if not edges:
                # No reverse edge recorded on this child yet: schedule it.
                pending.append(child)
            strong_holder = Holder(parent)
            holders.append(strong_holder)
            edges.append((weakref.ref(strong_holder), input_idx))
            child_meta["reverse_edges"] = edges

    return holders
def get_param_groups(inputs: List[Node], params: List[Node]) -> List[Dict[str, Any]]:
    """
    Given a list of inputs and a list of parameters, return a list of parameter
    groups, where each group contains the parameters and the intermediates that
    are connected to the parameters.

    Each returned group is a dictionary with the following keys:
    - "params": a set of parameters
    - "intermediates": a set of intermediates

    Groups that share an intermediate node are merged, and each distinct group
    object appears once in the returned list.
    """
    # reverse graph that starts with inputs, and goes up to the dOutput or the loss,
    # but omits weights and any subgraphs connecting weights to this closure
    inputs_closure, _ = reverse_closure(inputs, set())

    param_groups: Dict[Node, Dict[str, Set]] = {}  # keyed on intermediates
    for param in params:
        # Only the nodes where the walk from this param runs into the inputs
        # closure are needed here.
        _, intersected = reverse_closure([param], inputs_closure)
        param_group: Dict[str, Set] = {
            "params": {param},
            "intermediates": intersected,
        }
        for input_node in intersected:
            existing = param_groups.get(input_node, None)
            if existing is not None:
                # Fold the current group into the one already registered for
                # this intermediate and continue with the merged group.
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(
                    param_group["intermediates"]
                )
                param_group = existing
            else:
                param_groups[input_node] = param_group

    # Several intermediates can map to the same group object; de-duplicate by
    # identity while preserving first-seen order.
    seen_ids: Set[int] = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)

    # Note: the union of all group params is deliberately not asserted to equal
    # `params`. That only holds if the input tensor requires gradients,
    # otherwise the autograd graph will miss the first layer of inputs.
    return unique_param_groups
def stage_backward_input(
    stage_outputs: List[torch.Tensor],
    output_grads: Optional[List[torch.Tensor]],
    input_values: List[torch.Tensor],
    weights: Iterator[Parameter],
):
    """
    Compute the gradients for only the stage inputs with respect to the stage
    outputs.

    Besides computing input gradients, a pre-hook is registered on every
    intermediate node of every parameter group. When called, each hook stores
    its argument into ``param_group["grads"]`` at that intermediate's position.

    Returns:
        ``(dinputs, param_groups)``. ``dinputs`` is the result of
        ``torch.autograd.grad`` for ``input_values`` (also accumulated into
        each input's ``.grad``), or ``None`` when not every input requires
        grad. ``param_groups`` is the list produced by ``get_param_groups``.
    """

    def make_capture_hook(group, slot):
        # Bind `group` and `slot` per intermediate so each hook writes to its
        # own position in the group's "grads" list.
        def hook(grad_inputs):
            if group.get("grads", None) is None:
                group["grads"] = [None] * len(group["intermediates"])
            group["grads"][slot] = grad_inputs

        return hook

    output_nodes: List[Node] = [
        fn for fn in map(_get_grad_fn_or_grad_acc, stage_outputs) if fn
    ]
    input_nodes: List[Node] = [
        fn for fn in map(_get_grad_fn_or_grad_acc, input_values) if fn
    ]
    weight_nodes: List[Node] = [
        fn for fn in map(_get_grad_fn_or_grad_acc, weights) if fn
    ]

    # The holders must stay alive while the groups are computed; they are
    # dropped right after.
    holders = construct_reverse_graph(output_nodes)
    param_groups = get_param_groups(input_nodes, weight_nodes)
    del holders

    for group in param_groups:
        for slot, intermediate in enumerate(group["intermediates"]):
            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(make_capture_hook(group, slot))

    # Stage 0 inputs do not require grads? Should we skip in that case?
    if not all(tensor.requires_grad for tensor in input_values):
        return None, param_groups

    if output_grads is None:
        # In case this is the loss and there are no output_grads, then we just use 1s
        output_grads = [torch.ones_like(out) for out in stage_outputs]

    dinputs = torch.autograd.grad(
        stage_outputs,
        inputs=input_values,
        grad_outputs=output_grads,
        retain_graph=True,
    )

    # Accumulate the freshly computed gradients onto the inputs.
    for inp, dinp in zip(input_values, dinputs):
        if inp.grad is None:
            inp.grad = dinp
        else:
            inp.grad += dinp

    return dinputs, param_groups
def stage_backward_weight(
    weights: Iterator[Parameter], param_groups: List[Dict[str, Any]]
):
    """
    Compute and accumulate weight gradients for the given parameter groups.

    For each group, ``torch.autograd.grad`` is run from the group's
    intermediate nodes to its parameter nodes, using the gradients stored
    under ``param_group["grads"]`` as ``grad_outputs``. The result is
    accumulated into the matching weight's ``.grad``.

    Side effects:
        ``param_group["intermediates"]`` is deleted from every group.

    Returns:
        A list with one entry per weight, in the order the weights were
        provided, holding that weight's ``.grad`` after this call. Weights
        that are not part of any group keep the value they had on entry.
    """
    # map weights to param_group_weights
    grad_acc_to_weight = {}
    weight_grads = []
    for index, weight in enumerate(weights):
        grad_acc = _get_grad_fn_or_grad_acc(weight)
        grad_acc_to_weight[grad_acc] = weight, index
        weight_grads.append(weight.grad)

    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(
            GradientEdge(i, 0) for i in param_group["intermediates"]
        )
        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

        # Break a reference cycle caused inside stage_backward_input->get_hook->hook
        # The summarized cycle is:
        # `hook` -> cell -> param_group -> intermediates -> `hook`
        # because we install the hook function onto each of the intermediate autograd nodes.
        # We need to keep intermediates alive up until backward_weight, but we can free it now.
        del param_group["intermediates"]

        assert all(len(g) == 1 for g in param_group["grads"])
        # Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        dweights = torch.autograd.grad(
            intermediate_edges,
            weights_edges,
            grad_outputs=sum(param_group["grads"], tuple()),
        )
        for grad_acc, dw in zip(param_group["params"], dweights):
            weight, index = grad_acc_to_weight[grad_acc]
            if weight.grad is None:
                weight.grad = dw
            else:
                weight.grad += dw
            # Refresh the snapshot taken before the backward pass; without
            # this, a weight whose grad was None on entry would be reported
            # as None even though it has just been assigned.
            weight_grads[index] = weight.grad

    # return grads in the original order weights were provided in
    return weight_grads
def stage_backward(
    stage_output,
    output_grads,
    input_values,
    outputs_with_grads_idxs: Optional[List[int]] = None,  # deprecated, not used
):
    """
    Run backward for one stage and report the gradients of its inputs.

    Tensors are extracted from ``stage_output`` (which may be a tensor, or a
    nesting of tuples/lists/dicts) together with the matching entries of
    ``output_grads``, and ``torch.autograd.backward`` is called on them.
    This accumulates gradients into the leaves of the autograd trace.

    Returns:
        A list aligned with ``input_values``: ``val.grad`` for every tensor
        entry and ``None`` for every non-tensor entry.

    Raises:
        RuntimeError: wrapping any exception raised while extracting tensors
            or running backward, with debug info about the arguments.
    """
    if outputs_with_grads_idxs is not None:
        # Deprecated, not used in runtime calls, only exists in compiler
        stage_output = [stage_output[i] for i in outputs_with_grads_idxs]
        output_grads = [output_grads[i] for i in outputs_with_grads_idxs]

    try:
        # stage_output may be a composite datatype like dict. Extract all
        # individual tensor values here.
        stage_output_tensors: List[torch.Tensor] = []
        output_grad_tensors: List[Optional[torch.Tensor]] = []

        # Explicit-stack, pre-order traversal. Using a plain loop rather than a
        # nested recursive function avoids creating a closure that refers to
        # itself and to the tensor lists, i.e. a reference cycle that would
        # keep activation tensors alive until GC runs.
        pending = [(stage_output, output_grads)]
        while pending:
            output_val, grad_val = pending.pop()

            if isinstance(output_val, torch.Tensor):
                if not output_val.requires_grad and output_val.grad_fn is None:
                    continue
                assert isinstance(
                    grad_val, (torch.Tensor, type(None))
                ), f"Expected Tensor or None gradient but got {type(grad_val)}"
                stage_output_tensors.append(output_val)
                output_grad_tensors.append(grad_val)
                continue

            if isinstance(output_val, (tuple, list)):
                if grad_val is None:
                    continue
                assert isinstance(
                    grad_val, (tuple, list)
                ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
                assert len(output_val) == len(grad_val)
                children = list(zip(output_val, grad_val))
            elif isinstance(output_val, dict):
                if grad_val is None:
                    continue
                assert isinstance(grad_val, dict)
                assert set(output_val.keys()) == set(grad_val.keys())
                children = [(output_val[k], grad_val[k]) for k in output_val.keys()]
            else:
                # Output is a non-tensor type; just ignore it
                continue

            # Push in reverse so children are popped in their original order.
            pending.extend(reversed(children))

        torch.autograd.backward(
            stage_output_tensors, grad_tensors=output_grad_tensors  # type: ignore[arg-type]
        )

        # Extract gradients wrt the input values
        grad_inputs = [
            val.grad if isinstance(val, torch.Tensor) else None
            for val in input_values
        ]

        # Alternative impl: `torch.autograd.grad` over the inputs that require
        # grad. Note that `torch.autograd.grad` will not accumulate gradients
        # into the model's parameters.

    except Exception as e:
        # The message text below is kept exactly as it appears in the source.
        exc_msg = f"""
Failed to run stage backward:
Stage output: {map_debug_info(stage_output)}
Output gradient: {map_debug_info(output_grads)}
Input: {map_debug_info(input_values)}
"""
        raise RuntimeError(exc_msg) from e

    return grad_inputs
# TODO: handling requires_grad=False dynamically. Can we analyze this during initial
# IR emission?
def _null_coalesce_accumulate(lhs, rhs):
    """
    Combine two optional values.

    When both are present the result is ``torch.add(lhs, rhs)``. When one of
    them is ``None`` the other is returned as-is, so the result is ``None``
    only when both are ``None``.
    """
    if lhs is not None and rhs is not None:
        return torch.add(lhs, rhs)
    return rhs if lhs is None else lhs