Xuan Zhang | e027403dea | ILP for Auto SAC (Selective Activation Checkpointing) (#137908)
This PR presents a mixed integer linear programming (MILP) formulation that determines, under a given memory budget, which modules to apply activation checkpointing (AC) to and how much activation memory each of those modules should discard; a toy sketch of such a formulation is given after the list below. The MILP uses information collected from MemTracker, RuntimeEstimator, and SACEstimator, introduced in these PRs:
* https://github.com/pytorch/pytorch/pull/124688
* https://github.com/pytorch/pytorch/pull/134243
* https://github.com/pytorch/pytorch/pull/135208
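For intuition, the trade-off being optimized can be written as a tiny MILP in pulp. This is a deliberately simplified sketch with made-up per-module numbers, not the actual formulation in `sac_milp` (which is more involved, e.g., it tracks memory at module granularity across the execution order and uses the recomputation statistics produced by SACEstimator):
```
from pulp import PULP_CBC_CMD, LpBinary, LpMinimize, LpProblem, LpVariable, lpSum

# Hypothetical per-module statistics: activation memory (GiB) and the
# recomputation time (ms) incurred if all of that memory is discarded.
act_gib = {"layer0": 0.6, "layer1": 0.6, "layer2": 0.6, "layer3": 0.6}
recomp_ms = {"layer0": 3.0, "layer1": 3.0, "layer2": 3.0, "layer3": 3.0}
baseline_peak_gib, budget_gib = 2.41, 1.75

prob = LpProblem("toy_sac", LpMinimize)
y = {m: LpVariable(f"y_{m}", cat=LpBinary) for m in act_gib}  # AC on/off per module
d = {m: LpVariable(f"d_{m}", 0, 1) for m in act_gib}  # fraction of activations discarded
for m in act_gib:
    prob += d[m] <= y[m]  # a module can discard memory only if AC is applied to it
# Objective: minimize total recomputation time, assumed linear in the discard ratio.
prob += lpSum(recomp_ms[m] * d[m] for m in act_gib)
# Budget constraint: peak memory after discarding must fit within the budget.
prob += baseline_peak_gib - lpSum(act_gib[m] * d[m] for m in act_gib) <= budget_gib
prob.solve(PULP_CBC_CMD(msg=False))
print({m: round(d[m].value(), 4) for m in act_gib if y[m].value()})
```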
An end-to-end example and its sample output:
```
import copy
from typing import Tuple

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.ilp_utils import (
    aggregate_stats,
    get_peak_memory_runtime_baseline,
    parse_module_info,
)
from torch.distributed._tools.mem_tracker import _ModState, MemTracker
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.distributed._tools.sac_estimator import SACEstimator
from torch.distributed._tools.sac_ilp import sac_milp
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)


def _init_model_input_optimizer() -> Tuple[
    torch.nn.Module, torch.optim.Optimizer, torch.Tensor
]:
    bsz = 8
    model_args = ModelArgs(
        n_layers=4,
        n_heads=12,
        vocab_size=8192,
        max_seq_len=1024,
        dim=768,
        dropout_p=0.1,
    )
    with torch.device(torch.cuda.current_device()):
        model = Transformer(model_args)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
    inp = torch.randint(
        0,
        model_args.vocab_size,
        (bsz, model_args.max_seq_len),
        device=torch.cuda.current_device(),
    )
    return (model, optimizer, inp)


def _run_and_get_mem_tracker(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    inp: torch.Tensor,
) -> MemTracker:
    mem_tracker = MemTracker()
    mem_tracker.track_external(model, optimizer)
    last_snapshot = None
    with mem_tracker as mt:
        for iter_idx in range(2):  # running twice to initialize the optimizer
            output = model(inp)
            output.sum().backward()
            if iter_idx == 1:
                last_snapshot = mt.get_tracker_snapshot("current")
            optimizer.step()
            optimizer.zero_grad()
            if iter_idx == 0:
                mt.reset_mod_stats()
    assert last_snapshot is not None
    # Backfill a POST_BW snapshot for modules whose post-backward hook did
    # not fire, so that downstream aggregation sees a complete set of states.
    for mod_stats in mem_tracker.memory_tracking.values():
        if _ModState.POST_BW not in mod_stats.snapshots.keys():
            mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append(
                copy.deepcopy(last_snapshot)
            )
    return mem_tracker


def _run_and_get_runtime_estimator(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    inp: torch.Tensor,
) -> RuntimeEstimator:
    def _run_one_step() -> None:
        output = model(inp)
        output.sum().backward()
        optimizer.step()
        optimizer.zero_grad()

    # Initializing optimizer states and warm-up
    _run_one_step()

    runtime_estimator = RuntimeEstimator()
    with runtime_estimator(estimate_mode_type="operator-level-cost-model"):
        _run_one_step()  # We use only one iteration for estimation
    return runtime_estimator


def _run_and_get_sac_estimator(
    model: torch.nn.Module,
    inp: torch.Tensor,
) -> SACEstimator:
    sac_estimator = SACEstimator()
    with sac_estimator(estimate_mode_type="operator-level-cost-model"):
        loss = model(inp).sum()
        loss.backward()
    return sac_estimator


def main():
    # All statistics are collected under FakeTensorMode, so no real GPU
    # memory is allocated.
    with FakeTensorMode():
        model, optimizer, inp = _init_model_input_optimizer()
        mem_tracker = _run_and_get_mem_tracker(model, optimizer, inp)
        runtime_estimator = _run_and_get_runtime_estimator(model, optimizer, inp)
        sac_estimator = _run_and_get_sac_estimator(model, inp)
        # Aggregate the per-module statistics into the graph consumed by the MILP.
        mod_info = aggregate_stats(
            model,
            mem_tracker,
            runtime_estimator,
            sac_estimator,
            torch.device(torch.cuda.current_device()),
        )
        g = parse_module_info(mod_info)
        peak_mem, compute_time = get_peak_memory_runtime_baseline(g)
        print("=== WITHOUT AC ===")
        print(f"peak_mem: {round(peak_mem / 2**30, 2)} GiB")
        print(f"compute_time: {round(compute_time, 2)} ms")
        # Solve the MILP under a 1.75 GiB memory budget.
        ac_decisions, recomputation_time, peak_mem = sac_milp(g, memory_budget=1.75)
        print("=== WITH AC ===")
        print(f"ac_decisions: {ac_decisions}")
        print(f"peak_mem: {round(peak_mem / 2**30, 2)} GiB")
        print(f"recomputation_time: {recomputation_time} ms")


if __name__ == "__main__":
    main()
```
```
=== WITHOUT AC ===
peak_mem: 2.41 GiB
compute_time: 97.97 ms
=== WITH AC ===
ac_decisions: {'Transformer.layers.0': 0.5232, 'Transformer.layers.1': 0.5232, 'Transformer.layers.2': 0.6849, 'Transformer.layers.3': 0.5232}
peak_mem: 1.75 GiB
recomputation_time: 5.92 ms
```
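In the output above, `ac_decisions` maps each selected module's FQN to the fraction of its activation memory to discard. As a hedged illustration (not part of this PR), one way to act on these decisions is the existing `apply_activation_checkpointing` wrapper; the hypothetical helper below applies full checkpointing to every selected module, which over-approximates the fractional ratios (honoring them exactly would require a selective-checkpointing policy):
```
# Minimal sketch (not part of this PR): wrap each module chosen by the MILP
# with full activation checkpointing. `apply_ac_decisions` is a hypothetical
# helper; fractional discard ratios are approximated by full AC here.
from typing import Dict

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


def apply_ac_decisions(model: torch.nn.Module, ac_decisions: Dict[str, float]) -> None:
    # Keys look like "Transformer.layers.0"; strip the leading root-class
    # name to get submodule paths relative to `model`.
    paths = {fqn.split(".", 1)[1] for fqn in ac_decisions}
    chosen = {model.get_submodule(p) for p in paths}
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpoint_wrapper,
        check_fn=lambda mod: mod in chosen,
    )
```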
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137908
Approved by: https://github.com/weifengpy
2024-10-18 12:45:37 +00:00