pytorch/test/distributed/fsdp/test_fsdp_backward_prefetch.py

# Owner(s): ["oncall: distributed"]
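"""
Tests FSDP backward prefetching: ``_get_handle_to_prefetch`` is patched to
record which handle each backward prefetch targets, and the recorded parameter
FQNs are compared against the expected per-layer order for
``BackwardPrefetch.BACKWARD_PRE`` and ``BackwardPrefetch.BACKWARD_POST``.
"""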

import sys
from typing import List
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_handle_fqns_from_root
from torch.distributed.fsdp._flat_param import HandleTrainingState
from torch.distributed.fsdp._runtime_utils import (
    _get_handle_to_prefetch,
    _get_training_state,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


NUM_ITERS = 2
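
# Parameter FQNs (relative to the root module) of one wrapped decoder/encoder
# layer; "{index}" is substituted with the layer index when checking which
# handle a prefetch targeted.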
DECODER_PARAM_FQNS = [
"decoder.layers.{index}.self_attn.in_proj_weight",
"decoder.layers.{index}.self_attn.in_proj_bias",
"decoder.layers.{index}.self_attn.out_proj.weight",
"decoder.layers.{index}.self_attn.out_proj.bias",
"decoder.layers.{index}.multihead_attn.in_proj_weight",
"decoder.layers.{index}.multihead_attn.in_proj_bias",
"decoder.layers.{index}.multihead_attn.out_proj.weight",
"decoder.layers.{index}.multihead_attn.out_proj.bias",
"decoder.layers.{index}.linear1.weight",
"decoder.layers.{index}.linear1.bias",
"decoder.layers.{index}.linear2.weight",
"decoder.layers.{index}.linear2.bias",
"decoder.layers.{index}.norm1.weight",
"decoder.layers.{index}.norm1.bias",
"decoder.layers.{index}.norm2.weight",
"decoder.layers.{index}.norm2.bias",
"decoder.layers.{index}.norm3.weight",
"decoder.layers.{index}.norm3.bias",
]
ENCODER_PARAM_FQNS = [
"encoder.layers.{index}.self_attn.in_proj_weight",
"encoder.layers.{index}.self_attn.in_proj_bias",
"encoder.layers.{index}.self_attn.out_proj.weight",
"encoder.layers.{index}.self_attn.out_proj.bias",
"encoder.layers.{index}.linear1.weight",
"encoder.layers.{index}.linear1.bias",
"encoder.layers.{index}.linear2.weight",
"encoder.layers.{index}.linear2.bias",
"encoder.layers.{index}.norm1.weight",
"encoder.layers.{index}.norm1.bias",
"encoder.layers.{index}.norm2.weight",
"encoder.layers.{index}.norm2.bias",
]
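
# nn.Transformer defaults to 6 encoder and 6 decoder layers (12 wrapped layers
# in total).  BACKWARD_PRE prefetches every layer once per backward pass
# (decoder 5...0, then encoder 5...0 -> 12); BACKWARD_POST never prefetches
# decoder 5 because its backward runs first (decoder 4...0, then encoder
# 5...0 -> 11).  See the per-branch comments in _dist_train for the exact
# orders.  ENCODER_BEGIN_INDEX_FOR_* is the position in the recorded prefetch
# list where the encoder handles begin; ENCODER_PREFETCH_NUM is the highest
# encoder layer index that is prefetched.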
TOTAL_NUM_PREFETCH_FOR_PRE = 12
TOTAL_NUM_PREFETCH_FOR_POST = 11
ENCODER_BEGIN_INDEX_FOR_PRE = 6
ENCODER_BEGIN_INDEX_FOR_POST = 5
ENCODER_PREFETCH_NUM = 5

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestBackwardPrefetch(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
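        """Run NUM_ITERS training iterations while recording which handles
        backward prefetch targets, then check the record against the expected
        decoder/encoder layer order."""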
        rank = self.rank
        orig_get_handle_to_prefetch = _get_handle_to_prefetch

        torch.manual_seed(0)
        policy = ModuleWrapPolicy(
            {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        )
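        # Each TransformerEncoderLayer/TransformerDecoderLayer becomes its own
        # FSDP unit under this policy, so each layer has one handle to prefetch.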
        model = FSDP(
            nn.Transformer(d_model=1024, nhead=8, device="cuda"),
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=policy,
            use_orig_params=True,
            backward_prefetch=backward_prefetch,
        )
        optim = torch.optim.SGD(model.parameters(), lr=1e-2)

        # prepare input
        torch.manual_seed(rank + 1)
        src = torch.randn((10, 1, 1024), device="cuda")
        tgt = torch.randn((20, 1, 1024), device="cuda")

        # monkey patch _get_handle_to_prefetch to record the FQNs of each
        # handle chosen for backward prefetch
        all_handle_fqns: List[List[str]] = []

        def patched_get_handle_to_prefetch(*args, **kwargs):
            handle = orig_get_handle_to_prefetch(*args, **kwargs)

            self.assertEqual(
                len(args), 2, "expect _get_handle_to_prefetch(state, current_handle)"
            )
            state = args[0]
            current_handle = args[1]
            training_state = _get_training_state(current_handle)
            if (
                training_state == HandleTrainingState.BACKWARD_PRE
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
            ) or (
                training_state == HandleTrainingState.BACKWARD_POST
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
            ):
                nonlocal all_handle_fqns
                # FQNs prefixed from the root module
                # (state._exec_order_data.param_to_fqn)
                fqns = _get_handle_fqns_from_root(state, handle)
                all_handle_fqns.append(fqns)

            return handle

        # flat params from the prefetched handle should match
        # DECODER_PARAM_FQNS and ENCODER_PARAM_FQNS
        with patch(
            "torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch",
            patched_get_handle_to_prefetch,
        ):
            for _ in range(NUM_ITERS):
                optim.zero_grad()
                loss = model(src, tgt).sum()
                loss.backward()
                optim.step()

                if backward_prefetch is None:
                    self.assertEqual(len(all_handle_fqns), 0)
                    continue
                elif backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
                    # state._exec_order_data.handles_post_forward_order
                    # equals the forward order:
                    #   encoder 0...5 -> decoder 0...5 -> root
                    # pre-backward hook order:
                    #   root -> decoder 5...0 -> encoder 5...0
                    # prefetch order:
                    #   decoder 5...0 -> encoder 5...0 -> None
                    # None: when current_handle=encoder 0,
                    # _get_handle_to_prefetch returns None
                    # +1 is for the above None
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_PRE
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_PRE + 1
                    )
                elif backward_prefetch == BackwardPrefetch.BACKWARD_POST:
                    # state._exec_order_data.handles_post_forward_order
                    # equals the forward order (same as BACKWARD_PRE):
                    #   encoder 0...5 -> decoder 0...5 -> root
                    # post-backward hook (AccumulateGrad) order:
                    #   decoder 5, 4...0 -> encoder 5...0 -> root
                    # prefetch order:
                    #   decoder 4...0 -> encoder 5...0 -> None -> None
                    # 1st None: when current_handle=encoder 0,
                    # _get_handle_to_prefetch returns None
                    # 2nd None: when current_handle=root,
                    # _get_handle_to_prefetch finds decoder 5 but does not
                    # return it since decoder 5 is computed already
                    # +2 is for the above Nones
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_POST
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_POST + 2
                    )
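
                # Check that the i-th recorded prefetch targeted the expected
                # layer: decoder layers in reverse, then encoder layers in
                # reverse, then the trailing None(s) described above.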
                # ith_prefetch: 0, 1, 2, ... is the position in all_handle_fqns
                for ith_prefetch, fqns in enumerate(all_handle_fqns):
                    if ith_prefetch >= 0 and ith_prefetch < encoder_begin_index:
                        layer_index = encoder_begin_index - 1 - ith_prefetch
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in DECODER_PARAM_FQNS],
                        )
                    elif (
                        ith_prefetch >= encoder_begin_index
                        and ith_prefetch <= encoder_begin_index + ENCODER_PREFETCH_NUM
                    ):
                        layer_index = (
                            encoder_begin_index + ENCODER_PREFETCH_NUM - ith_prefetch
                        )
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in ENCODER_PARAM_FQNS],
                        )
                    else:
                        self.assertTrue(fqns is None)

                # reset the record before the next iteration
                all_handle_fqns = []

    @skip_if_lt_x_gpu(2)
    def test_backward_prefetch(self):
        # subtests reuse the process group to shorten test time
        self.run_subtests(
            {
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
            },
            self._test_backward_prefetch,
        )

    def _test_backward_prefetch(self, backward_prefetch: BackwardPrefetch):
        self._dist_train(backward_prefetch)
if __name__ == "__main__":
run_tests()