pytorch/torch/distributed/fsdp/_debug_utils.py
Yanli Zhao 6ca991cacf [Composable API] Add fully_shard debug function to print sharded tree structure, module names and managed param fqns (#99133)
Add a `fully_shard` debug function that prints the sharded tree structure in the following format and also returns module names and their managed parameter FQNs.

![Screenshot 2023-04-18 at 5 14 54 PM](https://user-images.githubusercontent.com/48731194/232931628-169a63a9-b4d5-4902-9cfd-f40113f3ec98.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99133
Approved by: https://github.com/rohan-varma
2023-04-19 19:27:43 +00:00

104 lines
4.1 KiB
Python

from typing import Dict, List, Tuple
import torch
import torch.distributed.fsdp.flat_param as flat_param_file
from torch.distributed.fsdp._common_utils import (
_apply_to_modules,
_get_module_fsdp_state,
clean_tensor_name,
)
def _get_sharded_module_tree_with_module_name_to_fqns(
model: torch.nn.Module,
) -> Tuple[str, Dict[str, List[str]]]:
"""
It is used for composable fully_shard() code path, it returns
1. sharded module tree info: each line reprents a submodule name that contats the
submodule's FQN and its submodule class name, if the submodule is sharded by `fully_shard`,
the submodule name will add a postfix with ' FULLY SHARDED'. Each increased tree
level adds 4 spaces before the printed name. A printed sharded module tree info for a toy model
is like this:
[CompositeModel] FULLY SHARDED
l1[Linear]
u1[UnitModule] FULLY SHARDED
u1.l1[Linear]
u1.seq[Sequential]
u1.seq.0[ReLU]
u1.seq.1[Linear]
u1.seq.2[ReLU]
u1.l2[Linear]
u2[UnitModule] FULLY SHARDED
u2.l1[Linear]
u2.seq[Sequential]
u2.seq.0[ReLU]
u2.seq.1[Linear]
u2.seq.2[ReLU]
u2.l2[Linear]
l2[Linear]
2. a dict mapping from the concated module FQN and class name to a list of its managed
original parameters' FQNs. An example of the dict for the above toy sharded model is like this:
{'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']
}
All FQNs are prefixed starting from ``model``.
Args:
model (torch.nn.Module): Root module (which may or may not be passed to
composable `fully_shard()`).
"""
def module_fn(
module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
):
num_spaces = tree_level * 4
trimed_prefix = (
prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
)
prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]"
printed_prefixed_module_name = " " * num_spaces + prefixed_module_name
state = _get_module_fsdp_state(module)
if state is None:
sharded_tree_info[0] += printed_prefixed_module_name + "\n"
return
handles = state._fully_sharded_module_to_handles.get(module, [])
if handles:
sharded_tree_info[0] += (
printed_prefixed_module_name + " FULLY SHARDED" + "\n"
)
else:
sharded_tree_info[0] += printed_prefixed_module_name + "\n"
for handle in handles:
param = handle.flat_param
assert type(param) is flat_param_file.FlatParameter
global_fqns = [
clean_tensor_name(prefix + name) for name in param._fqns
] # prefixed from the top level `model` (i.e. including `prefix`)
if prefixed_module_name in sharded_module_name_to_fqns:
sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
else:
sharded_module_name_to_fqns[prefixed_module_name] = global_fqns
def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
return sharded_tree_info[0], sharded_module_name_to_fqns
# Use List to mutate its value in place while running the recursive functions
sharded_tree_info: List[str] = [
"",
]
sharded_module_name_to_fqns: Dict[str, List[str]] = {}
return _apply_to_modules(
model,
module_fn,
return_fn,
[key for key, _ in model.named_parameters()],
sharded_tree_info,
sharded_module_name_to_fqns,
)