# Owner(s): ["oncall: distributed"]

import sys
from typing import List
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_handle_fqns_from_root
from torch.distributed.fsdp._flat_param import HandleTrainingState
from torch.distributed.fsdp._runtime_utils import (
    _get_handle_to_prefetch,
    _get_training_state,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


NUM_ITERS = 2
DECODER_PARAM_FQNS = [
    "decoder.layers.{index}.self_attn.in_proj_weight",
    "decoder.layers.{index}.self_attn.in_proj_bias",
    "decoder.layers.{index}.self_attn.out_proj.weight",
    "decoder.layers.{index}.self_attn.out_proj.bias",
    "decoder.layers.{index}.multihead_attn.in_proj_weight",
    "decoder.layers.{index}.multihead_attn.in_proj_bias",
    "decoder.layers.{index}.multihead_attn.out_proj.weight",
    "decoder.layers.{index}.multihead_attn.out_proj.bias",
    "decoder.layers.{index}.linear1.weight",
    "decoder.layers.{index}.linear1.bias",
    "decoder.layers.{index}.linear2.weight",
    "decoder.layers.{index}.linear2.bias",
    "decoder.layers.{index}.norm1.weight",
    "decoder.layers.{index}.norm1.bias",
    "decoder.layers.{index}.norm2.weight",
    "decoder.layers.{index}.norm2.bias",
    "decoder.layers.{index}.norm3.weight",
    "decoder.layers.{index}.norm3.bias",
]
ENCODER_PARAM_FQNS = [
    "encoder.layers.{index}.self_attn.in_proj_weight",
    "encoder.layers.{index}.self_attn.in_proj_bias",
    "encoder.layers.{index}.self_attn.out_proj.weight",
    "encoder.layers.{index}.self_attn.out_proj.bias",
    "encoder.layers.{index}.linear1.weight",
    "encoder.layers.{index}.linear1.bias",
    "encoder.layers.{index}.linear2.weight",
    "encoder.layers.{index}.linear2.bias",
    "encoder.layers.{index}.norm1.weight",
    "encoder.layers.{index}.norm1.bias",
    "encoder.layers.{index}.norm2.weight",
    "encoder.layers.{index}.norm2.bias",
]
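# nn.Transformer defaults to 6 encoder and 6 decoder layers, each wrapped as its
# own FSDP unit below; the counts below describe the expected number of recorded
# prefetches and where the encoder handles start in that order.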
TOTAL_NUM_PREFETCH_FOR_PRE = 12
TOTAL_NUM_PREFETCH_FOR_POST = 11
ENCODER_BEGIN_INDEX_FOR_PRE = 6
ENCODER_BEGIN_INDEX_FOR_POST = 5
ENCODER_PREFETCH_NUM = 5

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestBackwardPrefetch(FSDPTest):
    @property
    def world_size(self):
        return 2
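
    # Runs NUM_ITERS training steps on an FSDP-wrapped nn.Transformer and records
    # which handle _get_handle_to_prefetch returns during backward, so that the
    # prefetch order can be checked against the FQN lists above.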
    def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
        rank = self.rank
        orig_get_handle_to_prefetch = _get_handle_to_prefetch

        torch.manual_seed(0)
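        # wrap each transformer encoder/decoder layer as its own FSDP unit so that
        # backward prefetch moves between per-layer flat-parameter handles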
        policy = ModuleWrapPolicy(
            {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        )
        model = FSDP(
            nn.Transformer(d_model=1024, nhead=8, device="cuda"),
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=policy,
            use_orig_params=True,
            backward_prefetch=backward_prefetch,
        )
        optim = torch.optim.SGD(model.parameters(), lr=1e-2)

        # prepare input
        torch.manual_seed(rank + 1)
        src = torch.randn((10, 1, 1024), device="cuda")
        tgt = torch.randn((20, 1, 1024), device="cuda")

        # monkey patch
        all_handle_fqns: List[List[str]] = []
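
        # Wrap _get_handle_to_prefetch to record the parameter FQNs of the handle
        # it chooses whenever the backward training state matches the configured
        # backward_prefetch mode.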
        def patched_get_handle_to_prefetch(*args, **kwargs):
            handle = orig_get_handle_to_prefetch(*args, **kwargs)

            self.assertEqual(
                len(args), 2, "expect _get_handle_to_prefetch(state, current_handle)"
            )
            state = args[0]
            current_handle = args[1]
            training_state = _get_training_state(current_handle)
            if (
                training_state == HandleTrainingState.BACKWARD_PRE
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
            ) or (
                training_state == HandleTrainingState.BACKWARD_POST
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
            ):
                nonlocal all_handle_fqns
                # FQNs prefixed from the root module
                # state._exec_order_data.param_to_fqn
                fqns = _get_handle_fqns_from_root(state, handle)
                all_handle_fqns.append(fqns)
            return handle

        # flat params from prefetch handle should match
        # DECODER_PARAM_FQNS and ENCODER_PARAM_FQNS
        with patch(
            "torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch",
            patched_get_handle_to_prefetch,
        ):
            for _ in range(NUM_ITERS):
                optim.zero_grad()
                loss = model(src, tgt).sum()
                loss.backward()
                optim.step()
                if backward_prefetch is None:
                    self.assertEqual(len(all_handle_fqns), 0)
                    continue
                elif backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
                    # state._exec_order_data.handles_post_forward_order
                    # equals forward order
                    # encoder 0...5 -> decoder 0...5 -> root
                    # pre-backward hook order
                    # root -> decoder 5...0 -> encoder 5...0
                    # prefetch order
                    # decoder 5...0 -> encoder 5...0 -> None
                    # None: when current_handle=encoder 0,
                    # _get_handle_to_prefetch returns None
                    # +1 is for the above None
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_PRE
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_PRE + 1
                    )
                elif backward_prefetch == BackwardPrefetch.BACKWARD_POST:
                    # state._exec_order_data.handles_post_forward_order
                    # equals forward order (same as BACKWARD_PRE)
                    # encoder 0...5 -> decoder 0...5 -> root
                    # post-backward hook (AccumulateGrad) order
                    # decoder 5, 4...0 -> encoder 5...0 -> root
                    # prefetch order
                    # decoder 4...0 -> encoder 5...0 -> None -> None
                    # 1st None: when current_handle=encoder 0,
                    # _get_handle_to_prefetch returns None
                    # 2nd None: when current_handle=root,
                    # get decoder 5 inside _get_handle_to_prefetch
                    # but not needed since decoder 5 is computed already
                    # +2 is for the above Nones
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_POST
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_POST + 2
                    )

                # ith_prefetch: index of the recorded prefetch (0th, 1st, 2nd, ...)
                for ith_prefetch, fqns in enumerate(all_handle_fqns):
                    if ith_prefetch >= 0 and ith_prefetch < encoder_begin_index:
                        layer_index = encoder_begin_index - 1 - ith_prefetch
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in DECODER_PARAM_FQNS],
                        )
                    elif (
                        ith_prefetch >= encoder_begin_index
                        and ith_prefetch <= encoder_begin_index + ENCODER_PREFETCH_NUM
                    ):
                        layer_index = (
                            encoder_begin_index + ENCODER_PREFETCH_NUM - ith_prefetch
                        )
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in ENCODER_PARAM_FQNS],
                        )
                    else:
                        self.assertTrue(fqns is None)
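
                # clear the recorded FQNs before the next iteration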
                all_handle_fqns = []

    @skip_if_lt_x_gpu(2)
    def test_backward_prefetch(self):
        # subtests reuse the process group to shorten test time
        self.run_subtests(
            {
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
            },
            self._test_backward_prefetch,
        )

    def _test_backward_prefetch(self, backward_prefetch: BackwardPrefetch):
        self._dist_train(backward_prefetch)


if __name__ == "__main__":
    run_tests()