pytorch/torch/distributed/fsdp/_debug_utils.py
Chien-Chin Huang 591cb776af [FSDP][state_dict][optim_state_dict] Log slow optim and model state_dict paths (#108290)
This PR adds SimpleProfiler for logging FSDP state_dict/load_state_dict. SimpleProfiler uses class variables to record profiling results and does everything in Python, which can be slow, so it is only suitable for logging slow actions such as initialization and state_dict/load_state_dict.

This PR uses SimpleProfiler to log some critical/slow paths of the model and optimizer state_dict/load_state_dict.
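
A minimal usage sketch of the profiler added below (the profile type and log message here are illustrative, not the exact call sites in this PR):

    with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING):
        ...  # slow work being timed; elapsed time accumulates per profile type
    SimpleProfiler.dump_and_reset("FSDP state_dict profiling: ")  # log and clear the accumulated results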

Differential Revision: [D48774406](https://our.internmc.facebook.com/intern/diff/D48774406/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108290
Approved by: https://github.com/wz337
2023-09-01 06:57:59 +00:00


import logging
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Iterator, List, Set, Tuple
import torch
import torch.distributed.fsdp.flat_param as flat_param_file
from torch.distributed.fsdp._common_utils import (
    _apply_to_modules,
    _get_module_fsdp_state,
    clean_tensor_name,
)

logger = logging.getLogger(__name__)

class SimpleProfiler:
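    """Class-level profiler for slow FSDP paths.

    Wall-clock time spent under ``profile()`` is accumulated per profile type
    in the ``results`` class variable, so results are shared process-wide.
    Everything is done in Python, which can be slow, so this is only suitable
    for logging slow actions such as initialization and
    state_dict/load_state_dict.
    """
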
    class Type(str, Enum):
        ALL = "all"
        ALLGATHER = "all_gather"
        ALLGATHER_OBJ = "all_gather_object"
        RESHARDING = "resharding"
        H2D = "H2D"
        D2H = "D2H"

    results: Dict[str, float] = defaultdict(float)
    profiling: Set[str] = set()

    @classmethod
    def reset(cls) -> None:
        cls.results.clear()
        cls.profiling.clear()

    @classmethod
    @contextmanager
    def profile(cls, profile_type: str) -> Iterator[None]:
        assert profile_type not in cls.profiling, (
            f"{profile_type} is already being profiled. "
            "SimpleProfiler does not support profiling multiple instances at "
            "the same time. "
        )

        cls.profiling.add(profile_type)
        begin = time.monotonic()
        try:
            yield
        finally:
            end = time.monotonic()
            cls.results[profile_type] += end - begin
            cls.profiling.remove(profile_type)

    @classmethod
    def dump_and_reset(cls, msg: str) -> None:
        logger.warning("%s %s", msg, str(cls.results))
        cls.reset()

def _get_sharded_module_tree_with_module_name_to_fqns(
    model: torch.nn.Module,
) -> Tuple[str, Dict[str, List[str]]]:
"""
It is used for composable fully_shard() code path, it returns
1. sharded module tree info: each line reprents a submodule name that contats the
submodule's FQN and its submodule class name, if the submodule is sharded by `fully_shard`,
the submodule name will add a postfix with ' FULLY SHARDED'. Each increased tree
level adds 4 spaces before the printed name. A printed sharded module tree info for a toy model
is like this:
[CompositeModel] FULLY SHARDED
l1[Linear]
u1[UnitModule] FULLY SHARDED
u1.l1[Linear]
u1.seq[Sequential]
u1.seq.0[ReLU]
u1.seq.1[Linear]
u1.seq.2[ReLU]
u1.l2[Linear]
u2[UnitModule] FULLY SHARDED
u2.l1[Linear]
u2.seq[Sequential]
u2.seq.0[ReLU]
u2.seq.1[Linear]
u2.seq.2[ReLU]
u2.l2[Linear]
l2[Linear]
2. a dict mapping from the concated module FQN and class name to a list of its managed
original parameters' FQNs. An example of the dict for the above toy sharded model is like this:
{'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']
}
All FQNs are prefixed starting from ``model``.
Args:
model (torch.nn.Module): Root module (which may or may not be passed to
composable `fully_shard()`).
"""
    def module_fn(
        module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
    ):
        num_spaces = tree_level * 4
        trimed_prefix = (
            prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
        )
        prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]"
        printed_prefixed_module_name = " " * num_spaces + prefixed_module_name

        state = _get_module_fsdp_state(module)
        if state is None:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"
            return

        handle = state._fully_sharded_module_to_handle.get(module, None)

        if handle:
            sharded_tree_info[0] += (
                printed_prefixed_module_name + " FULLY SHARDED" + "\n"
            )
        else:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"

        if handle:
            param = handle.flat_param
            assert isinstance(param, flat_param_file.FlatParameter)
            global_fqns = [
                clean_tensor_name(prefix + name) for name in param._fqns
            ]  # prefixed from the top level `model` (i.e. including `prefix`)

            if prefixed_module_name in sharded_module_name_to_fqns:
                sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
            else:
                sharded_module_name_to_fqns[prefixed_module_name] = global_fqns

    def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
        return sharded_tree_info[0], sharded_module_name_to_fqns

    # Use List to mutate its value in place while running the recursive functions
    sharded_tree_info: List[str] = [
        "",
    ]
    sharded_module_name_to_fqns: Dict[str, List[str]] = {}

    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [key for key, _ in model.named_parameters()],
        sharded_tree_info,
        sharded_module_name_to_fqns,
    )