This PR presents a mixed integer linear programming (MILP) formulation that determines, under a memory budget, which modules to apply activation checkpointing (AC) to and how much activation memory to discard for each of them. The MILP uses information collected from MemTracker, Runtime Estimator, and SAC Estimator, introduced in these PRs:

* https://github.com/pytorch/pytorch/pull/124688
* https://github.com/pytorch/pytorch/pull/134243
* https://github.com/pytorch/pytorch/pull/135208

End-to-end example and its sample output:

```
import copy
from typing import Tuple

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.ilp_utils import (
    aggregate_stats,
    get_peak_memory_runtime_baseline,
    parse_module_info,
)
from torch.distributed._tools.mem_tracker import _ModState, MemTracker
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.distributed._tools.sac_estimator import SACEstimator
from torch.distributed._tools.sac_ilp import sac_milp
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)


def _init_model_input_optimizer() -> Tuple[
    torch.nn.Module, torch.optim.Optimizer, torch.Tensor
]:
    bsz = 8
    model_args = ModelArgs(
        n_layers=4,
        n_heads=12,
        vocab_size=8192,
        max_seq_len=1024,
        dim=768,
        dropout_p=0.1,
    )
    with torch.device(torch.cuda.current_device()):
        model = Transformer(model_args)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
    inp = torch.randint(
        0,
        model_args.vocab_size,
        (bsz, model_args.max_seq_len),
        device=torch.cuda.current_device(),
    )
    return (model, optimizer, inp)


def _run_and_get_mem_tracker(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    inp: torch.Tensor,
) -> MemTracker:
    mem_tracker = MemTracker()
    mem_tracker.track_external(model, optimizer)
    with mem_tracker as mt:
        for iter_idx in range(2):  # running twice to initialize optimizer
            output = model(inp)
            output.sum().backward()
            if iter_idx == 1:
                last_snapshot = mt.get_tracker_snapshot("current")
            optimizer.step()
            optimizer.zero_grad()
            if iter_idx == 0:
                mt.reset_mod_stats()
    assert last_snapshot is not None
    for mod_stats in mem_tracker.memory_tracking.values():
        if _ModState.POST_BW not in mod_stats.snapshots.keys():
            mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append(
                copy.deepcopy(last_snapshot)
            )
    return mem_tracker


def _run_and_get_runtime_estimator(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    inp: torch.Tensor,
) -> RuntimeEstimator:
    def _run_one_step() -> None:
        output = model(inp)
        output.sum().backward()
        optimizer.step()
        optimizer.zero_grad()

    # Initializing optimizer states and warm-up
    _run_one_step()

    runtime_estimator = RuntimeEstimator()
    with runtime_estimator(estimate_mode_type="operator-level-cost-model"):
        _run_one_step()  # We use only one iteration for estimation
    return runtime_estimator


def _run_and_get_sac_estimator(
    model: torch.nn.Module,
    inp: torch.Tensor,
) -> SACEstimator:
    sac_estimator = SACEstimator()
    with sac_estimator(estimate_mode_type="operator-level-cost-model"):
        loss = model(inp).sum()
    loss.backward()
    return sac_estimator


def main():
    with FakeTensorMode():
        model, optimizer, inp = _init_model_input_optimizer()
        mem_tracker = _run_and_get_mem_tracker(model, optimizer, inp)
        runtime_estimator = _run_and_get_runtime_estimator(model, optimizer, inp)
        sac_estimator = _run_and_get_sac_estimator(model, inp)
        mod_info = aggregate_stats(
            model,
            mem_tracker,
            runtime_estimator,
            sac_estimator,
            torch.device(torch.cuda.current_device()),
        )
        g = parse_module_info(mod_info)
        peak_mem, compute_time = get_peak_memory_runtime_baseline(g)
        print("=== WITHOUT AC ===")
        print(f"peak_mem: {round(peak_mem / 2**30, 2)} GiB")
        print(f"compute_time: {round(compute_time, 2)} ms")
        ac_decisions, recomputation_time, peak_mem = sac_milp(g, memory_budget=1.75)
        print("=== WITH AC ===")
        print(f"ac_decisions: {ac_decisions}")
        print(f"peak_mem: {round(peak_mem / 2**30, 2)} GiB")
        print(f"recomputation_time: {recomputation_time} ms")


if __name__ == "__main__":
    main()
```

```
=== WITHOUT AC ===
peak_mem: 2.41 GiB
compute_time: 97.97 ms
=== WITH AC ===
ac_decisions: {'Transformer.layers.0': 0.5232, 'Transformer.layers.1': 0.5232, 'Transformer.layers.2': 0.6849, 'Transformer.layers.3': 0.5232}
peak_mem: 1.75 GiB
recomputation_time: 5.92 ms
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137908
Approved by: https://github.com/weifengpy
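A hedged reading of the sample output above: per the description, each value in `ac_decisions` appears to be the fraction of that module's activation memory to discard and recompute in the backward pass. The helper below is a hypothetical illustration of that interpretation, not part of the PR:

```
# Hypothetical helper (not part of the PR): pretty-print an ac_decisions dict,
# reading each value as the fraction of a module's activation memory to discard.
def summarize_ac_decisions(ac_decisions):
    for fqn, discard_ratio in sorted(ac_decisions.items()):
        print(f"{fqn}: discard ~{discard_ratio:.1%} of activation memory")

summarize_ac_decisions(
    {
        "Transformer.layers.0": 0.5232,
        "Transformer.layers.1": 0.5232,
        "Transformer.layers.2": 0.6849,
        "Transformer.layers.3": 0.5232,
    }
)
```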
292 lines
9.9 KiB
Python
import copy
from typing import cast, Dict, List, OrderedDict, Tuple, TypedDict

import numpy as np

import torch
from torch.distributed._tools.mem_tracker import (
    _MemRefType,
    _ModMemStats,
    _ModState,
    MemTracker,
)
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStats


class ModOrder(TypedDict):
    fw_pre_order: List[str]
    bw_pre_order: List[str]
    fw_post_order: List[str]
    bw_post_order: List[str]


class ModRuntime(TypedDict):
    fw: float
    bw: float


class ModStats(TypedDict):
    fqn: str
    # per-module params
    param_per_module: int
    # per-module grads
    grad_per_module: int
    # total accumulated gradients up to and including this module
    grad_total: int
    # per-module fw activation size (excluding input and output)
    act_fw_per_module: int
    # per-module bw activation size during peak_bw
    act_bw_per_module: int
    # per-module activation grad size during peak_bw
    act_grad_per_module: int
    # total activation size up to but excluding the current module
    # includes input of the current module (i.e., output of previous module)
    act_total: int
    # Inputs to the module
    input_per_module: int
    # Outputs of the module
    output_per_module: int
    # Total fw run-time of the module
    fw_runtime_per_module: float
    # Total bw run-time of the module
    bw_runtime_per_module: float
    # Is this module a leaf module
    is_leaf: bool
    # Total ac run-time of the module
    sac_runtime: float
    # Total ac_memory for the module
    sac_memory: int
    # Number of piecewise-linear functions used for approximating the ac tradeoff curve
    n_segments: int
    # Slopes of the piecewise-linear functions
    slopes: List[float]
    # Intercepts of the piecewise-linear functions
    intercepts: List[float]
    # X breakpoints of the piecewise-linear functions
    breakpoints: List[float]
    # Original trade-off curves
    tradeoff_curve: OrderedDict[float, float]


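# A hedged, illustrative helper (not part of the original file). Assuming the
# piecewise-linear fields above follow the usual pwlf convention (n_segments
# segments, n_segments + 1 breakpoints, and one slope/intercept pair per
# segment), the approximated trade-off value at a point x in segment k is
# roughly slopes[k] * x + intercepts[k].
def _eval_sac_tradeoff_sketch(stats: ModStats, x: float) -> float:
    if stats["n_segments"] == 0:
        # Leaf modules carry no trade-off curve in this representation.
        return 0.0
    # Clamp x to the fitted range, then evaluate the segment containing it.
    x = min(max(x, stats["breakpoints"][0]), stats["breakpoints"][-1])
    for k in range(stats["n_segments"]):
        if x <= stats["breakpoints"][k + 1]:
            return stats["slopes"][k] * x + stats["intercepts"][k]
    # Fallback for numerical edge cases at the right end of the range.
    return stats["slopes"][-1] * x + stats["intercepts"][-1]

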
class ModuleInfo(TypedDict):
    mod_order: ModOrder
    mod_stats: List[ModStats]


def aggregate_stats(
    model: torch.nn.Module,
    mem_tracker: MemTracker,
    runtime_estimator: RuntimeEstimator,
    sac_estimator: SACEstimator,
    dev: torch.device,
) -> ModuleInfo:
    """
    Collect module-wise stats for a given model, including memory, runtime, and AC tradeoff stats.

    Args:
        model: nn.Module object
        mem_tracker: MemTracker object with memory stats
        runtime_estimator: RuntimeEstimator object with runtime stats
        sac_estimator: SACEstimator object with AC tradeoff stats
        dev: device the model was run on (used to extract memory stats from MemTracker)

    Returns:
        ModuleInfo: A dictionary with module order and module stats.
    """

    # Memory stats
    mod_mem_stats: Dict[torch.nn.Module, _ModMemStats] = dict(
        copy.deepcopy(mem_tracker.memory_tracking)
    )

    # Runtime stats
    mod_runtime_stats: Dict[str, ModRuntime] = {
        fqn: {"fw": v["fw"], "bw": v["bw"]}
        for fqn, v in runtime_estimator.mod_runtimes.items()
    }

    # Module order
    mod_order: ModOrder = {
        "fw_pre_order": list(runtime_estimator.mod_fw_pre_order),
        "bw_pre_order": list(runtime_estimator.mod_bw_pre_order),
        "fw_post_order": list(runtime_estimator.mod_fw_post_order),
        "bw_post_order": list(runtime_estimator.mod_bw_post_order),
    }

    # Selective Activation Checkpointing stats
    sac_estimator.pwlf_sac_tradeoff_curve()
    mod_sac_tradeoff_stats: Dict[str, SACTradeOffStats] = copy.deepcopy(
        sac_estimator.sac_mod_tradeoff_stats
    )

    module_info: ModuleInfo = {
        "mod_order": mod_order,
        "mod_stats": [],
    }

    for mod in model.modules():
        if mod_mem_stat := mod_mem_stats.get(mod, None):
            if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None):
                sac_runtime = tradeoff_stats.sac_runtime
                sac_memory = tradeoff_stats.sac_memory
                n_segments = tradeoff_stats.n_segments
                slopes = tradeoff_stats.slopes
                intercepts = tradeoff_stats.intercepts
                breakpoints = tradeoff_stats.fit_breaks
                tradeoff_curve = tradeoff_stats.tradeoff_curve
                is_leaf = False
            else:
                sac_runtime = sac_memory = n_segments = 0
                slopes = intercepts = breakpoints = []
                tradeoff_curve: OrderedDict[float, float] = OrderedDict()  # type: ignore[no-redef]
                is_leaf = True
            mod_stat: ModStats = {
                "fqn": mod_mem_stat.mod_fqn,
                "param_per_module": mod_mem_stat.parameter_mem,
                "grad_per_module": mod_mem_stat.parameter_mem,
                "grad_total": mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][
                    _MemRefType.GRAD
                ],
                "act_fw_per_module": max(
                    0,
                    mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][_MemRefType.ACT]
                    - mod_mem_stat.snapshots[_ModState.PRE_FW][-1][dev][_MemRefType.ACT]
                    - mod_mem_stat.output_mem,
                ),
                "act_bw_per_module": max(
                    0,
                    mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.ACT],
                ),
                "act_grad_per_module": (
                    mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.TEMP]
                    - mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][
                        _MemRefType.TEMP
                    ]
                ),
                "act_total": mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][
                    _MemRefType.ACT
                ],
                "input_per_module": mod_mem_stat.input_mem,
                "output_per_module": mod_mem_stat.output_mem,
                "fw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["fw"],
                "bw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["bw"],
                "is_leaf": is_leaf,
                "sac_runtime": sac_runtime,
                "sac_memory": sac_memory,
                "n_segments": n_segments,
                "slopes": slopes,
                "intercepts": intercepts,
                "breakpoints": breakpoints,
                "tradeoff_curve": tradeoff_curve,
            }
            module_info["mod_stats"].append(mod_stat)

    return module_info


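# Hedged usage sketch (not part of the original file): look up the stats of a
# single module in the output of `aggregate_stats` by its fully qualified name.
def _find_mod_stat_sketch(info: ModuleInfo, fqn: str) -> ModStats:
    for mod_stat in info["mod_stats"]:
        if mod_stat["fqn"] == fqn:
            return mod_stat
    raise KeyError(f"no stats collected for module {fqn!r}")

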
class Node(ModStats):
    index: int  # index according to forward pre-order
    pos_fw_post_order: int  # index according to forward post-order


class Graph:
    def __init__(self, n: int) -> None:
        self.nodes: List[Node] = []
        self.name2node: Dict[str, Node] = {}
        self.ad_matrix = np.zeros((n, n))
        self.fw_post_order: List[str] = []

    def add_node(self, node: Node) -> None:
        self.nodes.append(node)
        self.name2node[node["fqn"]] = node


def parse_module_info(module_info: ModuleInfo) -> Graph:
    """
    Parse module info and create a graph (tree) of modules. The graph will be
    used by the MILP solver to find optimal SAC and/or FSDP configurations.
    """
    mod_stats = module_info["mod_stats"]
    fw_pre_order = module_info["mod_order"]["fw_pre_order"]
    # assertion and number of nodes
    assert len(mod_stats) == len(fw_pre_order)
    n_nodes = len(mod_stats)

    # create graph
    g = Graph(n_nodes)
    g.fw_post_order = module_info["mod_order"]["fw_post_order"]

    # sort the modules by pre-order and add them to the graph
    module_info["mod_stats"] = sorted(
        mod_stats, key=lambda x: fw_pre_order.index(x["fqn"])
    )
    for i, one_mod_stats in enumerate(mod_stats):
        node: Node = cast(Node, one_mod_stats)
        node["index"] = i
        node["pos_fw_post_order"] = g.fw_post_order.index(node["fqn"])
        g.add_node(node)

    # set up ancestor-descendant matrix
    for i in range(n_nodes):
        for j in range(i, n_nodes):
            if is_self_or_submodule(g.nodes[j]["fqn"], g.nodes[i]["fqn"]):
                g.ad_matrix[i][j] = 1
            else:
                break

    return g


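# Hedged illustration (not part of the original file): for nodes in forward
# pre-order ["Transformer", "Transformer.layers.0", "Transformer.layers.0.attn",
# "Transformer.layers.1"] (hypothetical FQNs), the loop above would produce the
# upper-triangular ancestor-descendant matrix below, where entry [i][j] == 1
# exactly when node j is node i itself or one of its submodules. Pre-order keeps
# each subtree contiguous, which is why the early `break` is safe.
def _ad_matrix_example_sketch() -> np.ndarray:
    return np.array(
        [
            [1, 1, 1, 1],
            [0, 1, 1, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ]
    )

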
def is_self_or_submodule(name_descendant: str, name_ancestor: str) -> bool:
    """
    check if name_descendant is a submodule of name_ancestor, or if they are the same
    """
    return name_descendant == name_ancestor or name_ancestor + "." in name_descendant


def is_submodule(name_descendant: str, name_ancestor: str) -> bool:
    """
    check if name_descendant is a submodule of name_ancestor, but not the same
    """
    return name_ancestor + "." in name_descendant


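def _fqn_helpers_sketch() -> None:
    # Hedged, illustrative examples (not part of the original file) of the two
    # FQN helpers above, using hypothetical module names.
    assert is_self_or_submodule("Transformer.layers.0", "Transformer.layers.0")
    assert is_self_or_submodule("Transformer.layers.0.attn", "Transformer.layers.0")
    assert is_submodule("Transformer.layers.0.attn", "Transformer.layers.0")
    assert not is_submodule("Transformer.layers.0", "Transformer.layers.0")
    assert not is_self_or_submodule("Transformer.layers.1", "Transformer.layers.0")

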
def display_bytes(b: int, unit: str = "MiB") -> str:
|
|
"""
|
|
return a string that represent the number of bytes in a desired unit
|
|
"""
|
|
if unit == "KiB":
|
|
return f"{b/2**10:.2f} KiB"
|
|
if unit == "MiB":
|
|
return f"{b/2**20:.2f} MiB"
|
|
if unit == "GiB":
|
|
return f"{b/2**30:.2f} GiB"
|
|
return f"{b:.2f} bytes"
|
|
|
|
|
|
def get_peak_memory_runtime_baseline(graph: Graph) -> Tuple[int, float]:
    """
    Get the baseline peak memory and runtime.
    Baseline here means there is no FSDP or AC.
    Memory includes the parameters, gradients, activations, and activation gradients.
    Memory does not include, e.g., optimizer states, embedding tables, etc.

    Returns:
        int: peak memory in bytes
        float: compute time in ms
    """
    P_1 = graph.nodes[0]["param_per_module"]
    num_nodes = len(graph.nodes)
    peak_mem = 0
    for i in range(num_nodes):
        TG_i = graph.nodes[i]["grad_total"]
        AG_i = graph.nodes[i]["act_grad_per_module"]
        TA_i = graph.nodes[i]["act_total"]
        peak_mem = max(peak_mem, P_1 + TG_i + AG_i + TA_i)
    compute_time = (
        graph.nodes[0]["fw_runtime_per_module"]
        + graph.nodes[0]["bw_runtime_per_module"]
    )
    return (peak_mem, compute_time)
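

def _report_baseline_sketch(graph: Graph) -> None:
    # Hedged usage sketch (not part of the original file): combine the baseline
    # above with `display_bytes` for a quick human-readable report.
    peak_mem, compute_time = get_peak_memory_runtime_baseline(graph)
    print(f"baseline peak memory: {display_bytes(peak_mem, 'GiB')}")
    print(f"baseline compute time: {compute_time:.2f} ms")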