[Composable API] Add fully_shard debug function to print sharded tree structure, module names and managed param fqns (#99133)

This PR adds a `fully_shard` debug function that prints the sharded module tree structure in the format shown below, and also returns the module names together with their managed parameter FQNs.

![Screenshot 2023-04-18 at 5 14 54 PM](https://user-images.githubusercontent.com/48731194/232931628-169a63a9-b4d5-4902-9cfd-f40113f3ec98.png)
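
For reference, a rough usage sketch (the toy model and the launch setup are my own assumptions, not part of this PR; the helper name and import path are taken from the new file below, and a default process group is assumed to be initialized, e.g. via `torchrun`):

```python
# Rough usage sketch of the new debug helper; ToyModel is illustrative only.
import torch
import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp._debug_utils import (
    _get_sharded_module_tree_with_module_name_to_fqns,
)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(8, 8)
        self.l2 = nn.Linear(8, 8)

    def forward(self, x):
        return self.l2(self.l1(x))


model = ToyModel().cuda()
fully_shard(model)  # composable FSDP API

# Returns the printable tree string and a dict mapping
# "<module FQN>[<class name>]" -> list of managed parameter FQNs.
tree_str, module_name_to_fqns = _get_sharded_module_tree_with_module_name_to_fqns(model)
print(tree_str)
print(module_name_to_fqns)
```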

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99133
Approved by: https://github.com/rohan-varma
Yanli Zhao 2023-04-19 02:43:18 +00:00 committed by PyTorch MergeBot
parent 6b6dc4418d
commit 6ca991cacf
4 changed files with 227 additions and 8 deletions


@ -0,0 +1,115 @@
# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed as dist
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp._debug_utils import (
    _get_sharded_module_tree_with_module_name_to_fqns,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_dist_composable import CompositeModel, UnitModule
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestUtils(FSDPTest):
    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    def test_get_sharded_module_tree_with_module_name_to_fqns(self):
        model = CompositeModel(torch.device("cuda"))
        fully_shard(
            model,
            policy=ModuleWrapPolicy({UnitModule}),
        )
        (
            sharded_tree_info,
            sharded_module_name_to_fqns,
        ) = _get_sharded_module_tree_with_module_name_to_fqns(model)
        self.assertEqual(
            list(sharded_module_name_to_fqns.keys()),
            ["[CompositeModel]", "u1[UnitModule]", "u2[UnitModule]"],
        )
        self.assertEqual(
            list(sharded_module_name_to_fqns.values()),
            [
                ["l1.weight", "l1.bias", "l2.weight", "l2.bias"],
                [
                    "u1.l1.weight",
                    "u1.l1.bias",
                    "u1.seq.1.weight",
                    "u1.seq.1.bias",
                    "u1.l2.weight",
                    "u1.l2.bias",
                ],
                [
                    "u2.l1.weight",
                    "u2.l1.bias",
                    "u2.seq.1.weight",
                    "u2.seq.1.bias",
                    "u2.l2.weight",
                    "u2.l2.bias",
                ],
            ],
        )

        # Test nested fully_shard
        new_model = CompositeModel(torch.device("cuda"))
        fully_shard(new_model.u1)
        fully_shard(new_model)
        (
            sharded_tree_info,
            sharded_module_name_to_fqns,
        ) = _get_sharded_module_tree_with_module_name_to_fqns(new_model)
        self.assertEqual(
            list(sharded_module_name_to_fqns.keys()),
            ["[CompositeModel]", "u1[UnitModule]"],
        )
        self.assertEqual(
            list(sharded_module_name_to_fqns.values()),
            [
                [
                    "l1.weight",
                    "l1.bias",
                    "u2.l1.weight",
                    "u2.l1.bias",
                    "u2.seq.1.weight",
                    "u2.seq.1.bias",
                    "u2.l2.weight",
                    "u2.l2.bias",
                    "l2.weight",
                    "l2.bias",
                ],
                [
                    "u1.l1.weight",
                    "u1.l1.bias",
                    "u1.seq.1.weight",
                    "u1.seq.1.bias",
                    "u1.l2.weight",
                    "u1.l2.bias",
                ],
            ],
        )


if __name__ == "__main__":
    run_tests()


@ -58,7 +58,7 @@ class _FSDPState(_State):
         self._is_root: Optional[bool] = None
         self._handles: List[flat_param_file.FlatParamHandle] = []
         self._fully_sharded_module_to_handles: Dict[
-            nn.Module, flat_param_file.FlatParamHandle
+            nn.Module, List[flat_param_file.FlatParamHandle]
         ] = {}
         self.compute_device: Optional[torch.device] = None
         # All following attributes should only be used for root states:
@ -204,7 +204,7 @@ def _get_param_to_fqns(
         includes the FQNs across all encounters. (Default: ``True``)
     """

-    def module_fn(module, prefix, param_to_fqns):
+    def module_fn(module, prefix, tree_level, param_to_fqns):
         for param_name, param in module.named_parameters(recurse=False):
             local_fqns = (
                 param._fqns
@ -272,13 +272,14 @@ def _apply_to_modules(
         to remove the prefix.
     """

-    def f(module: torch.nn.Module, prefix: str, *args, **kwargs):
+    def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs):
         # Call the module function before recursing over children (pre-order)
-        module_fn(module, prefix, *args, **kwargs)
+        module_fn(module, prefix, tree_level, *args, **kwargs)
         for submodule_name, submodule in module.named_children():
             if submodule is None:
                 continue
             new_prefix = prefix + submodule_name + "."
+            new_tree_level = tree_level + 1
             if filter_fqns is not None:
                 for fqn in filter_fqns:
                     if fqn.startswith(new_prefix):
@ -308,9 +309,9 @@ def _apply_to_modules(
f"submodule_name = {submodule_name}" f"submodule_name = {submodule_name}"
) )
new_prefix = prefix new_prefix = prefix
f(submodule, new_prefix, *args, **kwargs) f(submodule, new_prefix, new_tree_level, *args, **kwargs)
f(root_module, "", *args, **kwargs) f(root_module, "", 0, *args, **kwargs)
return return_fn(*args, **kwargs) return return_fn(*args, **kwargs)
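
For context on the `tree_level` plumbing above: `_apply_to_modules` now passes the current depth to each `module_fn`, so the debug printer can indent four spaces per tree level. A minimal standalone sketch of this pre-order-traversal-with-depth pattern (my own illustration, not the library code):

import torch.nn as nn


def walk(module: nn.Module, prefix: str = "", tree_level: int = 0) -> None:
    # Visit the parent before recursing into children (pre-order), so each
    # visit sees its dotted prefix and its depth in the module tree.
    name = prefix[:-1] if prefix.endswith(".") else prefix
    print(" " * 4 * tree_level + f"{name}[{module.__class__.__name__}]")
    for child_name, child in module.named_children():
        walk(child, prefix + child_name + ".", tree_level + 1)


walk(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))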


@ -0,0 +1,103 @@
from typing import Dict, List, Tuple

import torch
import torch.distributed.fsdp.flat_param as flat_param_file
from torch.distributed.fsdp._common_utils import (
    _apply_to_modules,
    _get_module_fsdp_state,
    clean_tensor_name,
)


def _get_sharded_module_tree_with_module_name_to_fqns(
    model: torch.nn.Module,
) -> Tuple[str, Dict[str, List[str]]]:
"""
It is used for composable fully_shard() code path, it returns
1. sharded module tree info: each line reprents a submodule name that contats the
submodule's FQN and its submodule class name, if the submodule is sharded by `fully_shard`,
the submodule name will add a postfix with ' FULLY SHARDED'. Each increased tree
level adds 4 spaces before the printed name. A printed sharded module tree info for a toy model
is like this:
[CompositeModel] FULLY SHARDED
l1[Linear]
u1[UnitModule] FULLY SHARDED
u1.l1[Linear]
u1.seq[Sequential]
u1.seq.0[ReLU]
u1.seq.1[Linear]
u1.seq.2[ReLU]
u1.l2[Linear]
u2[UnitModule] FULLY SHARDED
u2.l1[Linear]
u2.seq[Sequential]
u2.seq.0[ReLU]
u2.seq.1[Linear]
u2.seq.2[ReLU]
u2.l2[Linear]
l2[Linear]
2. a dict mapping from the concated module FQN and class name to a list of its managed
original parameters' FQNs. An example of the dict for the above toy sharded model is like this:
{'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']
}
All FQNs are prefixed starting from ``model``.
Args:
model (torch.nn.Module): Root module (which may or may not be passed to
composable `fully_shard()`).
"""
    def module_fn(
        module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
    ):
        num_spaces = tree_level * 4
        trimed_prefix = (
            prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
        )
        prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]"
        printed_prefixed_module_name = " " * num_spaces + prefixed_module_name

        state = _get_module_fsdp_state(module)
        if state is None:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"
            return

        handles = state._fully_sharded_module_to_handles.get(module, [])

        if handles:
            sharded_tree_info[0] += (
                printed_prefixed_module_name + " FULLY SHARDED" + "\n"
            )
        else:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"

        for handle in handles:
            param = handle.flat_param
            assert type(param) is flat_param_file.FlatParameter
            global_fqns = [
                clean_tensor_name(prefix + name) for name in param._fqns
            ]  # prefixed from the top level `model` (i.e. including `prefix`)

            if prefixed_module_name in sharded_module_name_to_fqns:
                sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
            else:
                sharded_module_name_to_fqns[prefixed_module_name] = global_fqns

    def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
        return sharded_tree_info[0], sharded_module_name_to_fqns

    # Use List to mutate its value in place while running the recursive functions
    sharded_tree_info: List[str] = [
        "",
    ]
    sharded_module_name_to_fqns: Dict[str, List[str]] = {}
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [key for key, _ in model.named_parameters()],
        sharded_tree_info,
        sharded_module_name_to_fqns,
    )


@ -1100,7 +1100,7 @@ def _get_param_id_to_param_from_optim_input(
 def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[nn.Parameter, str]:
-    def module_fn(module, prefix, flat_param_to_fqn):
+    def module_fn(module, prefix, tree_level, flat_param_to_fqn):
         for param_name, param in module.named_parameters(recurse=False):
             if type(param) is not FlatParameter:
                 continue
@ -1533,7 +1533,7 @@ def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
         to unique parameters.
     """

-    def module_fn(module, prefix, fqn_to_param_info):
+    def module_fn(module, prefix, tree_level, fqn_to_param_info):
         fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
         if fsdp_state is None:
             return