# pytorch/test/distributed/fsdp/test_fsdp_backward_prefetch.py
# Owner(s): ["oncall: distributed"]
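"""
Test FSDP's backward prefetching (``BackwardPrefetch.BACKWARD_PRE`` /
``BACKWARD_POST``): ``_get_handle_to_prefetch`` is patched to record which
flat-parameter handle is prefetched at each step of the backward pass, and the
recorded FQNs are checked against the expected decoder/encoder layer order (or
against no prefetching at all when ``backward_prefetch=None``).
"""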
import sys
from typing import Optional
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_handle_fqns_from_root
from torch.distributed.fsdp._flat_param import HandleTrainingState
from torch.distributed.fsdp._runtime_utils import (
    _get_handle_to_prefetch,
    _get_training_state,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

device_type = torch.device(get_devtype())

NUM_ITERS = 2
DECODER_PARAM_FQNS = [
    "decoder.layers.{index}.self_attn.in_proj_weight",
    "decoder.layers.{index}.self_attn.in_proj_bias",
    "decoder.layers.{index}.self_attn.out_proj.weight",
    "decoder.layers.{index}.self_attn.out_proj.bias",
    "decoder.layers.{index}.multihead_attn.in_proj_weight",
    "decoder.layers.{index}.multihead_attn.in_proj_bias",
    "decoder.layers.{index}.multihead_attn.out_proj.weight",
    "decoder.layers.{index}.multihead_attn.out_proj.bias",
    "decoder.layers.{index}.linear1.weight",
    "decoder.layers.{index}.linear1.bias",
    "decoder.layers.{index}.linear2.weight",
    "decoder.layers.{index}.linear2.bias",
    "decoder.layers.{index}.norm1.weight",
    "decoder.layers.{index}.norm1.bias",
    "decoder.layers.{index}.norm2.weight",
    "decoder.layers.{index}.norm2.bias",
    "decoder.layers.{index}.norm3.weight",
    "decoder.layers.{index}.norm3.bias",
]
ENCODER_PARAM_FQNS = [
    "encoder.layers.{index}.self_attn.in_proj_weight",
    "encoder.layers.{index}.self_attn.in_proj_bias",
    "encoder.layers.{index}.self_attn.out_proj.weight",
    "encoder.layers.{index}.self_attn.out_proj.bias",
    "encoder.layers.{index}.linear1.weight",
    "encoder.layers.{index}.linear1.bias",
    "encoder.layers.{index}.linear2.weight",
    "encoder.layers.{index}.linear2.bias",
    "encoder.layers.{index}.norm1.weight",
    "encoder.layers.{index}.norm1.bias",
    "encoder.layers.{index}.norm2.weight",
    "encoder.layers.{index}.norm2.bias",
]
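# Each nn.TransformerEncoderLayer / nn.TransformerDecoderLayer is wrapped as its
# own FSDP unit by the ModuleWrapPolicy below, so a single prefetched handle
# should cover exactly one layer's parameters listed above.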
TOTAL_NUM_PREFETCH_FOR_PRE = 12
TOTAL_NUM_PREFETCH_FOR_POST = 11
ENCODER_BEGIN_INDEX_FOR_PRE = 6
ENCODER_BEGIN_INDEX_FOR_POST = 5
ENCODER_PREFETCH_NUM = 5
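# nn.Transformer defaults to 6 encoder and 6 decoder layers, so BACKWARD_PRE
# prefetches all 12 wrapped layers, while BACKWARD_POST prefetches only 11:
# decoder.layers.5 runs its post-backward first and is never prefetched
# (see the comments in _dist_train below).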

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestBackwardPrefetch(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
        rank = self.rank
        orig_get_handle_to_prefetch = _get_handle_to_prefetch

        torch.manual_seed(0)
        policy = ModuleWrapPolicy(
            {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        )
        model = FSDP(
            nn.Transformer(d_model=1024, nhead=8, device=device_type),
            device_id=device_type.type,
            auto_wrap_policy=policy,
            use_orig_params=True,
            backward_prefetch=backward_prefetch,
        )
        optim = torch.optim.SGD(model.parameters(), lr=1e-2)

        # prepare input
        torch.manual_seed(rank + 1)
        src = torch.randn((10, 1, 1024), device=device_type)
        tgt = torch.randn((20, 1, 1024), device=device_type)

        # monkey patch; entries are None when nothing is prefetched
        all_handle_fqns: list[Optional[list[str]]] = []

        def patched_get_handle_to_prefetch(*args, **kwargs):
            handle = orig_get_handle_to_prefetch(*args, **kwargs)
            self.assertEqual(
                len(args), 2, "expect _get_handle_to_prefetch(state, current_handle)"
            )
            state = args[0]
            current_handle = args[1]
            training_state = _get_training_state(current_handle)
            if (
                training_state == HandleTrainingState.BACKWARD_PRE
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
            ) or (
                training_state == HandleTrainingState.BACKWARD_POST
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
            ):
                nonlocal all_handle_fqns
                # FQNs prefixed from the root module
                # state._exec_order_data.param_to_fqn
                fqns = _get_handle_fqns_from_root(state, handle)
                all_handle_fqns.append(fqns)
            return handle

        # flat params from the prefetched handle should match
        # DECODER_PARAM_FQNS and ENCODER_PARAM_FQNS
        with patch(
            "torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch",
            patched_get_handle_to_prefetch,
        ):
            for _ in range(NUM_ITERS):
                optim.zero_grad()
                loss = model(src, tgt).sum()
                loss.backward()
                optim.step()

                if backward_prefetch is None:
                    self.assertEqual(len(all_handle_fqns), 0)
                    continue
                elif backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
                    # state._exec_order_data.handles_post_forward_order
                    # equals the forward order:
                    #   encoder 0...5 -> decoder 0...5 -> root
                    # pre-backward hook order:
                    #   root -> decoder 5...0 -> encoder 5...0
                    # prefetch order:
                    #   decoder 5...0 -> encoder 5...0 -> None
                    # None: when current_handle=encoder 0,
                    # _get_handle_to_prefetch returns None
                    # +1 is for the above None
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_PRE
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_PRE + 1
                    )
                elif backward_prefetch == BackwardPrefetch.BACKWARD_POST:
                    # state._exec_order_data.handles_post_forward_order
                    # equals the forward order (same as BACKWARD_PRE):
                    #   encoder 0...5 -> decoder 0...5 -> root
                    # post-backward hook (AccumulateGrad) order:
                    #   decoder 5, 4...0 -> encoder 5...0 -> root
                    # prefetch order:
                    #   decoder 4...0 -> encoder 5...0 -> None -> None
                    # 1st None: when current_handle=encoder 0,
                    # _get_handle_to_prefetch returns None
                    # 2nd None: when current_handle=root, decoder 5 is found
                    # inside _get_handle_to_prefetch but is not needed since
                    # its gradient has already been computed
                    # +2 is for the above Nones
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_POST
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_POST + 2
                    )
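
                # Worked example of the index mapping checked below:
                # BACKWARD_PRE (encoder_begin_index=6):
                #   ith_prefetch 0..5   -> decoder.layers.5 ... decoder.layers.0
                #   ith_prefetch 6..11  -> encoder.layers.5 ... encoder.layers.0
                #   ith_prefetch 12     -> None
                # BACKWARD_POST (encoder_begin_index=5):
                #   ith_prefetch 0..4   -> decoder.layers.4 ... decoder.layers.0
                #   ith_prefetch 5..10  -> encoder.layers.5 ... encoder.layers.0
                #   ith_prefetch 11, 12 -> None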
                # ith_prefetch: the 0th, 1st, 2nd, ... prefetch recorded above
                for ith_prefetch, fqns in enumerate(all_handle_fqns):
                    if ith_prefetch >= 0 and ith_prefetch < encoder_begin_index:
                        layer_index = encoder_begin_index - 1 - ith_prefetch
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in DECODER_PARAM_FQNS],
                        )
                    elif (
                        ith_prefetch >= encoder_begin_index
                        and ith_prefetch <= encoder_begin_index + ENCODER_PREFETCH_NUM
                    ):
                        layer_index = (
                            encoder_begin_index + ENCODER_PREFETCH_NUM - ith_prefetch
                        )
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in ENCODER_PARAM_FQNS],
                        )
                    else:
                        self.assertTrue(fqns is None)

                # reset the record for the next iteration
                all_handle_fqns = []

    @skip_if_lt_x_gpu(2)
    def test_backward_prefetch(self):
        # subtests reuse the process group to shorten test time
        self.run_subtests(
            {
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
            },
            self._test_backward_prefetch,
        )

    def _test_backward_prefetch(self, backward_prefetch: BackwardPrefetch):
        self._dist_train(backward_prefetch)


if __name__ == "__main__":
    run_tests()