pytorch/torch/distributed/_tools/ilp_utils.py
Xuan Zhang e027403dea ILP for Auto SAC (Selective Activation Checkpointing) (#137908)
This PR presents a mixed integer linear programming (MILP) formulation that can be used to determine, under a memory budget, which modules to apply activation checkpointing (AC) to and how much activation memory should be discarded for each such module. The MILP uses information collected from MemTracker, Runtime Estimator, and SAC Estimator, introduced in these PRs:
* https://github.com/pytorch/pytorch/pull/124688
* https://github.com/pytorch/pytorch/pull/134243
* https://github.com/pytorch/pytorch/pull/135208
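
To give a sense of the optimization being solved, below is a deliberately simplified sketch written with the `pulp` modeling library. It is *not* the formulation implemented in `sac_ilp.py` (the real MILP works on the module tree and uses per-module piecewise-linear recomputation curves from the SAC Estimator, plus binary variables for whether a module is checkpointed at all); every name and number in the sketch is invented for illustration.

```
import pulp

modules = ["layers.0", "layers.1", "layers.2", "layers.3"]
act_mem_gib = {m: 0.6 for m in modules}          # activation memory per module (GiB)
recomp_ms_per_gib = {m: 10.0 for m in modules}   # recompute cost per discarded GiB (ms)
static_mem_gib = 0.5                             # params, grads, etc. (GiB)
budget_gib = 1.75

prob = pulp.LpProblem("toy_sac", pulp.LpMinimize)
# d[m]: fraction of module m's activation memory to discard (and recompute in backward)
d = {
    m: pulp.LpVariable(f"d_{i}", lowBound=0, upBound=1)
    for i, m in enumerate(modules)
}

# Objective: minimize total recomputation time.
prob += pulp.lpSum(recomp_ms_per_gib[m] * act_mem_gib[m] * d[m] for m in modules)
# Constraint: retained activations plus static memory must fit within the budget.
prob += (
    static_mem_gib + pulp.lpSum(act_mem_gib[m] * (1 - d[m]) for m in modules)
    <= budget_gib
)

prob.solve(pulp.PULP_CBC_CMD(msg=False))
print({m: round(d[m].value(), 4) for m in modules})
```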

End-to-end example and its sample output:

```
import copy
from typing import Tuple

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from torch.distributed._tools.ilp_utils import (
    aggregate_stats,
    get_peak_memory_runtime_baseline,
    parse_module_info,
)
from torch.distributed._tools.mem_tracker import _ModState, MemTracker
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.distributed._tools.sac_estimator import SACEstimator
from torch.distributed._tools.sac_ilp import sac_milp
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)

def _init_model_input_optimizer() -> Tuple[
    torch.nn.Module, torch.optim.Optimizer, torch.Tensor
]:
    bsz = 8
    model_args = ModelArgs(
        n_layers=4,
        n_heads=12,
        vocab_size=8192,
        max_seq_len=1024,
        dim=768,
        dropout_p=0.1,
    )
    with torch.device(torch.cuda.current_device()):
        model = Transformer(model_args)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
    inp = torch.randint(
        0,
        model_args.vocab_size,
        (bsz, model_args.max_seq_len),
        device=torch.cuda.current_device(),
    )
    return (model, optimizer, inp)

def _run_and_get_mem_tracker(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    inp: torch.Tensor,
) -> MemTracker:
    mem_tracker = MemTracker()
    mem_tracker.track_external(model, optimizer)
    with mem_tracker as mt:
        for iter_idx in range(2):  # running twice to initialize optimizer
            output = model(inp)
            output.sum().backward()
            if iter_idx == 1:
                last_snapshot = mt.get_tracker_snapshot("current")
            optimizer.step()
            optimizer.zero_grad()
            if iter_idx == 0:
                mt.reset_mod_stats()
    assert last_snapshot is not None
    for mod_stats in mem_tracker.memory_tracking.values():
        if _ModState.POST_BW not in mod_stats.snapshots.keys():
            mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append(
                copy.deepcopy(last_snapshot)
            )
    return mem_tracker

def _run_and_get_runtime_estimator(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    inp: torch.Tensor,
) -> RuntimeEstimator:
    def _run_one_step() -> None:
        output = model(inp)
        output.sum().backward()
        optimizer.step()
        optimizer.zero_grad()

    # Initializing optimizer states and warm-up
    _run_one_step()

    runtime_estimator = RuntimeEstimator()
    with runtime_estimator(estimate_mode_type="operator-level-cost-model"):
        _run_one_step()  # We use only one iteration for estimation
    return runtime_estimator

def _run_and_get_sac_estimator(
    model: torch.nn.Module,
    inp: torch.Tensor,
) -> SACEstimator:
    sac_estimator = SACEstimator()
    with sac_estimator(estimate_mode_type="operator-level-cost-model"):
        loss = model(inp).sum()
    loss.backward()
    return sac_estimator

def main():
    with FakeTensorMode():
        model, optimizer, inp = _init_model_input_optimizer()
        mem_tracker = _run_and_get_mem_tracker(model, optimizer, inp)
        runtime_estimator = _run_and_get_runtime_estimator(model, optimizer, inp)
        sac_estimator = _run_and_get_sac_estimator(model, inp)
        mod_info = aggregate_stats(
            model,
            mem_tracker,
            runtime_estimator,
            sac_estimator,
            torch.device(torch.cuda.current_device()),
        )
        g = parse_module_info(mod_info)

        peak_mem, compute_time = get_peak_memory_runtime_baseline(g)
        print("=== WITHOUT AC ===")
        print(f"peak_mem: {round(peak_mem / 2**30, 2)} GiB")
        print(f"compute_time: {round(compute_time, 2)} ms")

        ac_decisions, recomputation_time, peak_mem = sac_milp(g, memory_budget=1.75)
        print("=== WITH AC ===")
        print(f"ac_decisions: {ac_decisions}")
        print(f"peak_mem: {round(peak_mem / 2**30, 2)} GiB")
        print(f"recomputation_time: {recomputation_time} ms")

if __name__ == "__main__":
    main()
```
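
Note that the entire script runs under `FakeTensorMode`, so no real GPU memory is allocated and no real kernels are launched; the peak-memory and runtime numbers below come from the trackers and estimators rather than from an actual training run.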

```
=== WITHOUT AC ===
peak_mem: 2.41 GiB
compute_time: 97.97 ms
=== WITH AC ===
ac_decisions: {'Transformer.layers.0': 0.5232, 'Transformer.layers.1': 0.5232, 'Transformer.layers.2': 0.6849, 'Transformer.layers.3': 0.5232}
peak_mem: 1.75 GiB
recomputation_time: 5.92 ms
```
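
Each entry of `ac_decisions` is the per-module discard ratio chosen by the solver, i.e., the fraction of that transformer layer's activation memory to drop during forward and recompute during backward. With roughly half to two-thirds of each layer's activations discarded, peak memory fits the 1.75 GiB budget at a cost of about 5.92 ms of recomputation per iteration, on top of the ~98 ms baseline compute time.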

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137908
Approved by: https://github.com/weifengpy
2024-10-18 12:45:37 +00:00


import copy
from typing import cast, Dict, List, OrderedDict, Tuple, TypedDict
import numpy as np
import torch
from torch.distributed._tools.mem_tracker import (
    _MemRefType,
    _ModMemStats,
    _ModState,
    MemTracker,
)
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStats


class ModOrder(TypedDict):
    fw_pre_order: List[str]
    bw_pre_order: List[str]
    fw_post_order: List[str]
    bw_post_order: List[str]


class ModRuntime(TypedDict):
    fw: float
    bw: float


class ModStats(TypedDict):
    fqn: str
    # per-module params
    param_per_module: int
    # per-module grads
    grad_per_module: int
    # total accumulated gradients up to and including this module
    grad_total: int
    # per module fw activation size (excluding input and output)
    act_fw_per_module: int
    # per module bw activation size during peak_bw
    act_bw_per_module: int
    # per module activation grad size during peak_bw
    act_grad_per_module: int
    # total activation size up to but excluding the current module
    # includes input of the current module (i.e., output of previous module)
    act_total: int
    # Inputs to the module
    input_per_module: int
    # Outputs of the module
    output_per_module: int
    # Total fw run-time of the module
    fw_runtime_per_module: float
    # Total bw run-time of the module
    bw_runtime_per_module: float
    # Is this module a leaf module
    is_leaf: bool
    # Total ac run-time of the module
    sac_runtime: float
    # Total ac_memory for the module
    sac_memory: int
    # Number of piecewise-linear functions used for approximating ac tradeoff curve
    n_segments: int
    # Slopes of the piecewise-linear functions
    slopes: List[float]
    # Intercepts of the piecewise-linear functions
    intercepts: List[float]
    # X breakpoints of the piecewise-linear functions
    breakpoints: List[float]
    # Original trade-off curves
    tradeoff_curve: OrderedDict[float, float]


class ModuleInfo(TypedDict):
    mod_order: ModOrder
    mod_stats: List[ModStats]


def aggregate_stats(
    model: torch.nn.Module,
    mem_tracker: MemTracker,
    runtime_estimator: RuntimeEstimator,
    sac_estimator: SACEstimator,
    dev: torch.device,
) -> ModuleInfo:
    """
    Collect modulewise stats for a given model, including memory, runtime, and AC tradeoff stats.

    Args:
        model: nn.Module object
        mem_tracker: MemTracker object with memory stats
        runtime_estimator: RuntimeEstimator object with runtime stats
        sac_estimator: SACEstimator object with AC tradeoff stats
        dev: device the model was run on (used to extract memory stats from MemTracker)

    Returns:
        ModuleInfo: A dictionary with module order and module stats.
    """
    # Memory stats
    mod_mem_stats: Dict[torch.nn.Module, _ModMemStats] = dict(
        copy.deepcopy(mem_tracker.memory_tracking)
    )

    # Runtime stats
    mod_runtime_stats: Dict[str, ModRuntime] = {
        fqn: {"fw": v["fw"], "bw": v["bw"]}
        for fqn, v in runtime_estimator.mod_runtimes.items()
    }

    # Module order
    mod_order: ModOrder = {
        "fw_pre_order": list(runtime_estimator.mod_fw_pre_order),
        "bw_pre_order": list(runtime_estimator.mod_bw_pre_order),
        "fw_post_order": list(runtime_estimator.mod_fw_post_order),
        "bw_post_order": list(runtime_estimator.mod_bw_post_order),
    }

    # Selective Activation Checkpointing stats
    sac_estimator.pwlf_sac_tradeoff_curve()
    mod_sac_tradeoff_stats: Dict[str, SACTradeOffStats] = copy.deepcopy(
        sac_estimator.sac_mod_tradeoff_stats
    )

    module_info: ModuleInfo = {
        "mod_order": mod_order,
        "mod_stats": [],
    }

    for mod in model.modules():
        if mod_mem_stat := mod_mem_stats.get(mod, None):
            if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None):
                sac_runtime = tradeoff_stats.sac_runtime
                sac_memory = tradeoff_stats.sac_memory
                n_segments = tradeoff_stats.n_segments
                slopes = tradeoff_stats.slopes
                intercepts = tradeoff_stats.intercepts
                breakpoints = tradeoff_stats.fit_breaks
                tradeoff_curve = tradeoff_stats.tradeoff_curve
                is_leaf = False
            else:
                sac_runtime = sac_memory = n_segments = 0
                slopes = intercepts = breakpoints = []
                tradeoff_curve: OrderedDict[float, float] = OrderedDict()  # type: ignore[no-redef]
                is_leaf = True
            mod_stat: ModStats = {
                "fqn": mod_mem_stat.mod_fqn,
                "param_per_module": mod_mem_stat.parameter_mem,
                # gradient memory is assumed to equal parameter memory
                "grad_per_module": mod_mem_stat.parameter_mem,
                "grad_total": mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][
                    _MemRefType.GRAD
                ],
                "act_fw_per_module": max(
                    0,
                    mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][_MemRefType.ACT]
                    - mod_mem_stat.snapshots[_ModState.PRE_FW][-1][dev][_MemRefType.ACT]
                    - mod_mem_stat.output_mem,
                ),
                "act_bw_per_module": max(
                    0,
                    mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.ACT],
                ),
                "act_grad_per_module": (
                    mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.TEMP]
                    - mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][
                        _MemRefType.TEMP
                    ]
                ),
                "act_total": mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][
                    _MemRefType.ACT
                ],
                "input_per_module": mod_mem_stat.input_mem,
                "output_per_module": mod_mem_stat.output_mem,
                "fw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["fw"],
                "bw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["bw"],
                "is_leaf": is_leaf,
                "sac_runtime": sac_runtime,
                "sac_memory": sac_memory,
                "n_segments": n_segments,
                "slopes": slopes,
                "intercepts": intercepts,
                "breakpoints": breakpoints,
                "tradeoff_curve": tradeoff_curve,
            }
            module_info["mod_stats"].append(mod_stat)

    return module_info


class Node(ModStats):
    index: int  # index according to forward pre-order
    pos_fw_post_order: int  # index according to forward post-order


class Graph:
    def __init__(self, n: int) -> None:
        self.nodes: List[Node] = []
        self.name2node: Dict[str, Node] = {}
        self.ad_matrix = np.zeros((n, n))
        self.fw_post_order: List[str] = []

    def add_node(self, node: Node) -> None:
        self.nodes.append(node)
        self.name2node[node["fqn"]] = node


def parse_module_info(module_info: ModuleInfo) -> Graph:
    """
    Parse module info and create a graph (tree) of modules. The graph will be
    used by the MILP solver to find optimal SAC and/or FSDP configurations.
    """
    mod_stats = module_info["mod_stats"]
    fw_pre_order = module_info["mod_order"]["fw_pre_order"]
    # assertion and number of nodes
    assert len(mod_stats) == len(fw_pre_order)
    n_nodes = len(mod_stats)

    # create graph
    g = Graph(n_nodes)
    g.fw_post_order = module_info["mod_order"]["fw_post_order"]

    # sort the modules by forward pre-order and add them to the graph
    module_info["mod_stats"] = sorted(
        mod_stats, key=lambda x: fw_pre_order.index(x["fqn"])
    )
    for i, one_mod_stats in enumerate(module_info["mod_stats"]):
        node: Node = cast(Node, one_mod_stats)
        node["index"] = i
        node["pos_fw_post_order"] = g.fw_post_order.index(node["fqn"])
        g.add_node(node)

    # set up ancestor-descendant matrix:
    # ad_matrix[i][j] == 1 iff node j is node i itself or one of its submodules.
    # Since nodes are in forward pre-order, a module's submodules are contiguous
    # right after it, so scanning can stop at the first non-descendant.
    for i in range(n_nodes):
        for j in range(i, n_nodes):
            if is_self_or_submodule(g.nodes[j]["fqn"], g.nodes[i]["fqn"]):
                g.ad_matrix[i][j] = 1
            else:
                break

    return g


def is_self_or_submodule(name_descendant: str, name_ancestor: str) -> bool:
    """
    Check if name_descendant is a submodule of name_ancestor, or if they are the same,
    e.g., is_self_or_submodule("Transformer.layers.0.attention", "Transformer.layers.0")
    returns True.
    """
    return name_descendant == name_ancestor or name_ancestor + "." in name_descendant


def is_submodule(name_descendant: str, name_ancestor: str) -> bool:
    """
    Check if name_descendant is a submodule of name_ancestor, but not the same module.
    """
    return name_ancestor + "." in name_descendant


def display_bytes(b: int, unit: str = "MiB") -> str:
    """
    Return a string that represents the number of bytes in the desired unit,
    e.g., display_bytes(3 * 2**20, "MiB") -> "3.00 MiB".
    """
    if unit == "KiB":
        return f"{b/2**10:.2f} KiB"
    if unit == "MiB":
        return f"{b/2**20:.2f} MiB"
    if unit == "GiB":
        return f"{b/2**30:.2f} GiB"
    return f"{b:.2f} bytes"


def get_peak_memory_runtime_baseline(graph: Graph) -> Tuple[int, float]:
    """
    Get the baseline peak memory and runtime.
    Baseline here means there is no FSDP or AC.
    Memory includes the parameters, gradients, activations, and activation gradients.
    Memory does not include e.g., optimizer states, embedding tables, etc.

    Returns:
        int: peak memory in bytes
        float: compute time in ms
    """
    P_1 = graph.nodes[0]["param_per_module"]
    num_nodes = len(graph.nodes)
    peak_mem = 0
    for i in range(num_nodes):
        TG_i = graph.nodes[i]["grad_total"]
        AG_i = graph.nodes[i]["act_grad_per_module"]
        TA_i = graph.nodes[i]["act_total"]
        peak_mem = max(peak_mem, P_1 + TG_i + AG_i + TA_i)
    compute_time = (
        graph.nodes[0]["fw_runtime_per_module"]
        + graph.nodes[0]["bw_runtime_per_module"]
    )
    return (peak_mem, compute_time)