This PR adds SimpleProfiler for FSDP state_dict/load_state_dict logging purposes. SimpleProfiler uses class variables to record profiling results and does everything in Python, which can be slow, so it is only suitable for logging slow actions such as initialization and state_dict/load_state_dict. This PR uses SimpleProfiler to log some critical/slow paths of the model and optimizer state_dict/load_state_dict.

Differential Revision: [D48774406](https://our.internmc.facebook.com/intern/diff/D48774406/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108290
Approved by: https://github.com/wz337
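For reference, here is a minimal usage sketch of how SimpleProfiler is intended to wrap a slow path such as `state_dict`. It is not part of this PR: it assumes an already-constructed FSDP-wrapped `model`, and the import path `torch.distributed.fsdp._debug_utils` is an assumption, not stated on this page.

```python
from torch.distributed.fsdp._debug_utils import SimpleProfiler  # assumed path

# Accumulate wall-clock time under the "all" category while state_dict runs.
with SimpleProfiler.profile(SimpleProfiler.Type.ALL):
    model_state = model.state_dict()

# Log the accumulated timings at WARNING level, then clear them.
SimpleProfiler.dump_and_reset("FSDP model state_dict profiling: ")
```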
import logging
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from typing import Dict, Iterator, List, Set, Tuple

import torch
import torch.distributed.fsdp.flat_param as flat_param_file
from torch.distributed.fsdp._common_utils import (
    _apply_to_modules,
    _get_module_fsdp_state,
    clean_tensor_name,
)

logger = logging.getLogger(__name__)


# SimpleProfiler records profiling results in class variables and does all of
# its timing in Python, which can be slow, so it is only suitable for logging
# slow actions such as initialization and state_dict/load_state_dict.
class SimpleProfiler:
    class Type(str, Enum):
        ALL = "all"
        ALLGATHER = "all_gather"
        ALLGATHER_OBJ = "all_gather_object"
        RESHARDING = "resharding"
        H2D = "H2D"
        D2H = "D2H"

    results: Dict[str, float] = defaultdict(float)
    profiling: Set[str] = set()

    @classmethod
    def reset(cls) -> None:
        cls.results.clear()
        cls.profiling.clear()

    @classmethod
    @contextmanager
    def profile(cls, profile_type: str) -> Iterator[None]:
        assert profile_type not in cls.profiling, (
            f"{profile_type} is already being profiled. "
            "SimpleProfiler does not support profiling multiple instances at "
            "the same time. "
        )

        cls.profiling.add(profile_type)
        begin = time.monotonic()
        try:
            yield
        finally:
            end = time.monotonic()
            cls.results[profile_type] += end - begin
            cls.profiling.remove(profile_type)

    @classmethod
    def dump_and_reset(cls, msg: str) -> None:
        logger.warning("%s %s", msg, str(cls.results))
        cls.reset()


def _get_sharded_module_tree_with_module_name_to_fqns(
    model: torch.nn.Module,
) -> Tuple[str, Dict[str, List[str]]]:
    """
    It is used for the composable fully_shard() code path and returns:

    1. sharded module tree info: each line represents a submodule name that contains the
       submodule's FQN and its submodule class name. If the submodule is sharded by `fully_shard`,
       the submodule name has the postfix ' FULLY SHARDED' appended. Each increased tree
       level adds 4 spaces of indentation before the printed name. A printed sharded module tree
       for a toy model looks like this:
            [CompositeModel] FULLY SHARDED
                l1[Linear]
                u1[UnitModule] FULLY SHARDED
                    u1.l1[Linear]
                    u1.seq[Sequential]
                        u1.seq.0[ReLU]
                        u1.seq.1[Linear]
                        u1.seq.2[ReLU]
                    u1.l2[Linear]
                u2[UnitModule] FULLY SHARDED
                    u2.l1[Linear]
                    u2.seq[Sequential]
                        u2.seq.0[ReLU]
                        u2.seq.1[Linear]
                        u2.seq.2[ReLU]
                    u2.l2[Linear]
                l2[Linear]
    2. a dict mapping from the concatenated module FQN and class name to a list of its managed
       original parameters' FQNs. An example of the dict for the above toy sharded model:
            {'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'],
             'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'],
             'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias']
            }

    All FQNs are prefixed starting from ``model``.

    Args:
        model (torch.nn.Module): Root module (which may or may not be passed to
                                 composable `fully_shard()`).
    """

    def module_fn(
        module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns
    ):
        num_spaces = tree_level * 4
        trimed_prefix = (
            prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix
        )
        prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]"
        printed_prefixed_module_name = " " * num_spaces + prefixed_module_name

        state = _get_module_fsdp_state(module)
        if state is None:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"
            return

        handle = state._fully_sharded_module_to_handle.get(module, None)

        if handle:
            sharded_tree_info[0] += (
                printed_prefixed_module_name + " FULLY SHARDED" + "\n"
            )
        else:
            sharded_tree_info[0] += printed_prefixed_module_name + "\n"

        if handle:
            param = handle.flat_param
            assert isinstance(param, flat_param_file.FlatParameter)
            global_fqns = [
                clean_tensor_name(prefix + name) for name in param._fqns
            ]  # prefixed from the top level `model` (i.e. including `prefix`)

            if prefixed_module_name in sharded_module_name_to_fqns:
                sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns)
            else:
                sharded_module_name_to_fqns[prefixed_module_name] = global_fqns

    def return_fn(sharded_tree_info, sharded_module_name_to_fqns):
        return sharded_tree_info[0], sharded_module_name_to_fqns

    # Use List to mutate its value in place while running the recursive functions
    sharded_tree_info: List[str] = [
        "",
    ]
    sharded_module_name_to_fqns: Dict[str, List[str]] = {}
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [key for key, _ in model.named_parameters()],
        sharded_tree_info,
        sharded_module_name_to_fqns,
    )
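As a usage sketch for the helper above: the snippet below builds a toy model shaped like the one in the docstring, shards it with the composable `fully_shard()` API, and prints the returned tree and FQN mapping. It assumes a distributed process group is already initialized; the toy module classes and the import path `torch.distributed.fsdp._debug_utils` are illustrative assumptions, not taken from this page.

```python
import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp._debug_utils import (  # assumed module path
    _get_sharded_module_tree_with_module_name_to_fqns,
)


class UnitModule(nn.Module):
    # Hypothetical submodule mirroring the toy model in the docstring.
    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(4, 4)
        self.seq = nn.Sequential(nn.ReLU(), nn.Linear(4, 4), nn.ReLU())
        self.l2 = nn.Linear(4, 4)


class CompositeModel(nn.Module):
    # Hypothetical root module mirroring the toy model in the docstring.
    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(4, 4)
        self.u1 = UnitModule()
        self.u2 = UnitModule()
        self.l2 = nn.Linear(4, 4)


# Assumes torch.distributed.init_process_group() has already been called.
model = CompositeModel()
fully_shard(model.u1)
fully_shard(model.u2)
fully_shard(model)

tree_str, module_name_to_fqns = _get_sharded_module_tree_with_module_name_to_fqns(model)
print(tree_str)  # the indented tree shown in the docstring
print(module_name_to_fqns["u1[UnitModule]"])  # e.g. ['u1.l1.weight', 'u1.l1.bias', ...]
```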