[FSDP] verify backward_prefetch works correctly with unit test (#107058)

issue resolved: https://github.com/pytorch/pytorch/pull/105984

context:
* CI did not catch the commit that broke backward_prefetch: https://github.com/pytorch/pytorch/pull/105006
* we had an action item to add a unit test to prevent similar regressions: https://github.com/pytorch/pytorch/pull/105984

what's included in this unit test:
* monkey patch torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch and check which handles are prefetched
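
In essence, the test wraps `_get_handle_to_prefetch` so every call records what it returned, then compares the recorded order against the expected prefetch order. A condensed sketch of the idea (not the exact test code below; `model`, `src`, and `tgt` are assumed to be set up as in the new test file, and `recording_get_handle_to_prefetch` / `prefetched_handles` are illustrative names):

```python
from unittest.mock import patch

from torch.distributed.fsdp._runtime_utils import _get_handle_to_prefetch

prefetched_handles = []  # None entries mean "nothing to prefetch"

def recording_get_handle_to_prefetch(state, current_handle):
    # delegate to the real implementation, then remember what it returned
    handle = _get_handle_to_prefetch(state, current_handle)
    prefetched_handles.append(handle)
    return handle

with patch(
    "torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch",
    recording_get_handle_to_prefetch,
):
    model(src, tgt).sum().backward()
# prefetched_handles now reflects the backward prefetch order
```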

for backward_prefetch = BackwardPrefetch.BACKWARD_PRE
* state._exec_order_data.handles_post_forward_order equals forward order: encoder 0...5 -> decoder 0...5 -> root
* pre-backward hook order: root -> decoder 5...0 -> encoder 5...0
* prefetch order: decoder 5...0 -> encoder 5...0 -> None
  * when current_handle=encoder 0, _get_handle_to_prefetch returns None
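
Since BACKWARD_PRE prefetches in reverse forward order, the expected layer index for the i-th recorded prefetch is easy to derive; a small, self-contained illustration mirroring how the test formats the layer index into its FQN templates (the single template string here is just one entry of the test's DECODER_PARAM_FQNS):

```python
DECODER_FQN_TEMPLATE = "decoder.layers.{index}.self_attn.in_proj_weight"

# prefetches 0..5 should be decoder layers 5..0 (reverse of forward order)
for ith_prefetch in range(6):
    layer_index = 5 - ith_prefetch
    expected_first_fqn = DECODER_FQN_TEMPLATE.format(index=layer_index)
    # ith_prefetch == 0 -> "decoder.layers.5.self_attn.in_proj_weight"
```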

for backward_prefetch = BackwardPrefetch.BACKWARD_POST
* state._exec_order_data.handles_post_forward_order equals forward order: encoder 0...5 -> decoder 0...5 -> root
* post-backward hook (AccumulateGrad) order: decoder 5, 4...0 -> encoder 5...0 -> root
* prefetch order: decoder 4...0 -> encoder 5...0 -> None -> None
  * 1st None: when current_handle=encoder 0, _get_handle_to_prefetch returns None
  * 2nd None: when current_handle=root, _get_handle_to_prefetch finds decoder 5 internally, but it is not needed (decoder 5 has already run), so it returns None
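
The expected number of recorded calls follows directly from the orders above; a quick tally using the constants defined in the new test file below (the default nn.Transformer has 6 encoder and 6 decoder layers):

```python
# constants as defined in the test file below
TOTAL_NUM_PREFETCH_FOR_PRE = 12
TOTAL_NUM_PREFETCH_FOR_POST = 11

# BACKWARD_PRE: 6 decoder + 6 encoder prefetches, plus 1 trailing None
assert TOTAL_NUM_PREFETCH_FOR_PRE + 1 == 6 + 6 + 1  # 13 recorded calls

# BACKWARD_POST: decoder 5 runs first in backward and is never a prefetch
# target, so only 5 decoder + 6 encoder prefetches, plus 2 trailing Nones
assert TOTAL_NUM_PREFETCH_FOR_POST + 2 == 5 + 6 + 2  # 13 recorded calls
```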
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107058
Approved by: https://github.com/awgu
weifengpy 2023-08-24 22:51:02 +00:00 committed by PyTorch MergeBot
parent 485de73004
commit ec10b17cfb
3 changed files with 249 additions and 17 deletions


@@ -0,0 +1,220 @@
# Owner(s): ["oncall: distributed"]

import sys
from typing import List
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_handle_fqns_from_root
from torch.distributed.fsdp._runtime_utils import (
    _get_handle_to_prefetch,
    _get_training_state,
)
from torch.distributed.fsdp.flat_param import HandleTrainingState
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

NUM_ITERS = 2

DECODER_PARAM_FQNS = [
    "decoder.layers.{index}.self_attn.in_proj_weight",
    "decoder.layers.{index}.self_attn.in_proj_bias",
    "decoder.layers.{index}.self_attn.out_proj.weight",
    "decoder.layers.{index}.self_attn.out_proj.bias",
    "decoder.layers.{index}.multihead_attn.in_proj_weight",
    "decoder.layers.{index}.multihead_attn.in_proj_bias",
    "decoder.layers.{index}.multihead_attn.out_proj.weight",
    "decoder.layers.{index}.multihead_attn.out_proj.bias",
    "decoder.layers.{index}.linear1.weight",
    "decoder.layers.{index}.linear1.bias",
    "decoder.layers.{index}.linear2.weight",
    "decoder.layers.{index}.linear2.bias",
    "decoder.layers.{index}.norm1.weight",
    "decoder.layers.{index}.norm1.bias",
    "decoder.layers.{index}.norm2.weight",
    "decoder.layers.{index}.norm2.bias",
    "decoder.layers.{index}.norm3.weight",
    "decoder.layers.{index}.norm3.bias",
]
ENCODER_PARAM_FQNS = [
    "encoder.layers.{index}.self_attn.in_proj_weight",
    "encoder.layers.{index}.self_attn.in_proj_bias",
    "encoder.layers.{index}.self_attn.out_proj.weight",
    "encoder.layers.{index}.self_attn.out_proj.bias",
    "encoder.layers.{index}.linear1.weight",
    "encoder.layers.{index}.linear1.bias",
    "encoder.layers.{index}.linear2.weight",
    "encoder.layers.{index}.linear2.bias",
    "encoder.layers.{index}.norm1.weight",
    "encoder.layers.{index}.norm1.bias",
    "encoder.layers.{index}.norm2.weight",
    "encoder.layers.{index}.norm2.bias",
]
TOTAL_NUM_PREFETCH_FOR_PRE = 12
TOTAL_NUM_PREFETCH_FOR_POST = 11
ENCODER_BEGIN_INDEX_FOR_PRE = 6
ENCODER_BEGIN_INDEX_FOR_POST = 5
ENCODER_PREFETCH_NUM = 5

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestBackwardPrefetch(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
        rank = self.rank
        orig_get_handle_to_prefetch = _get_handle_to_prefetch

        torch.manual_seed(0)
        policy = ModuleWrapPolicy(
            {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        )
        model = FSDP(
            nn.Transformer(d_model=1024, nhead=8, device="cuda"),
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=policy,
            use_orig_params=True,
            backward_prefetch=backward_prefetch,
        )
        optim = torch.optim.SGD(model.parameters(), lr=1e-2)

        # prepare input
        torch.manual_seed(rank + 1)
        src = torch.randn((10, 1, 1024), device="cuda")
        tgt = torch.randn((20, 1, 1024), device="cuda")

        # monkey patch
        all_handle_fqns: List[List[str]] = []

        def patched_get_handle_to_prefetch(*args, **kwargs):
            handle = orig_get_handle_to_prefetch(*args, **kwargs)
            self.assertEqual(
                len(args), 2, "expect _get_handle_to_prefetch(state, current_handle)"
            )
            state = args[0]
            current_handle = args[1]
            training_state = _get_training_state(current_handle)
            if (
                training_state == HandleTrainingState.BACKWARD_PRE
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
            ) or (
                training_state == HandleTrainingState.BACKWARD_POST
                and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
            ):
                nonlocal all_handle_fqns
                # FQNs prefixed from the root module
                # state._exec_order_data.param_to_fqn
                fqns = _get_handle_fqns_from_root(state, handle)
                all_handle_fqns.append(fqns)
            return handle

        # flat params from prefetch handle should match
        # DECODER_PARAM_FQNS and ENCODER_PARAM_FQNS
        with patch(
            "torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch",
            patched_get_handle_to_prefetch,
        ):
            for _ in range(NUM_ITERS):
                optim.zero_grad()
                loss = model(src, tgt).sum()
                loss.backward()
                optim.step()

                if backward_prefetch is None:
                    self.assertEqual(len(all_handle_fqns), 0)
                    continue
                elif backward_prefetch == BackwardPrefetch.BACKWARD_PRE:
                    # state._exec_order_data.handles_post_forward_order
                    # equals forward order
                    # encoder 0...5 -> decoder 0...5 -> root
                    # pre-backward hook order
                    # root -> decoder 5...0 -> encoder 5...0
                    # prefetch order
                    # decoder 5...0 -> encoder 5...0 -> None
                    # None: when current_handle=encoder 0,
                    # _get_handle_to_prefetch returns None
                    # +1 is for the above None
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_PRE
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_PRE + 1
                    )
                elif backward_prefetch == BackwardPrefetch.BACKWARD_POST:
                    # state._exec_order_data.handles_post_forward_order
                    # equals forward order (same as BACKWARD_PRE)
                    # encoder 0...5 -> decoder 0...5 -> root
                    # post-backward hook (AccumulateGrad) order
                    # decoder 5, 4...0 -> encoder 5...0 -> root
                    # prefetch order
                    # decoder 4...0 -> encoder 5...0 -> None -> None
                    # 1st None: when current_handle=encoder 0,
                    # _get_handle_to_prefetch returns None
                    # 2nd None: when current_handle=root,
                    # get decoder 5 inside _get_handle_to_prefetch
                    # but not needed since decoder 5 is computed already
                    # +2 is for the above Nones
                    encoder_begin_index = ENCODER_BEGIN_INDEX_FOR_POST
                    self.assertEqual(
                        len(all_handle_fqns), TOTAL_NUM_PREFETCH_FOR_POST + 2
                    )

                # ith_prefetch: 0, 1st, 2nd, 3rd, 4th ... ith prefetch
                for ith_prefetch, fqns in enumerate(all_handle_fqns):
                    if ith_prefetch >= 0 and ith_prefetch < encoder_begin_index:
                        layer_index = encoder_begin_index - 1 - ith_prefetch
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in DECODER_PARAM_FQNS],
                        )
                    elif (
                        ith_prefetch >= encoder_begin_index
                        and ith_prefetch <= encoder_begin_index + ENCODER_PREFETCH_NUM
                    ):
                        layer_index = (
                            encoder_begin_index + ENCODER_PREFETCH_NUM - ith_prefetch
                        )
                        self.assertEqual(
                            fqns,
                            [x.format(index=layer_index) for x in ENCODER_PARAM_FQNS],
                        )
                    else:
                        self.assertTrue(fqns is None)

                all_handle_fqns = []

    @skip_if_lt_x_gpu(2)
    def test_backward_prefetch(self):
        # subtest reuse process group to shorten test time
        self.run_subtests(
            {
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
            },
            self._test_backward_prefetch,
        )

    def _test_backward_prefetch(self, backward_prefetch: BackwardPrefetch):
        self._dist_train(backward_prefetch)


if __name__ == "__main__":
    run_tests()


@@ -2,6 +2,7 @@
This file includes private common utilities for FSDP.
"""
import logging
import traceback
import warnings
import weakref
@@ -351,6 +352,32 @@ def _get_param_to_fqns(
    )


@no_type_check
def _log_post_backward_hook(
    state: _FSDPState, handle: "FlatParamHandle", log: logging.Logger
) -> None:
    # Under TORCH_DISTRIBUTED_DEBUG=INFO, log the module names this hook fires for.
    # Below logging of module names this post-bwd hook fires for can help debug certain
    # cases where hooks don't fire, such as under certain activation checkpoint configs.
    if state._use_orig_params and handle._debug_level == dist.DebugLevel.INFO:
        param_fqns = _get_handle_fqns_from_root(state, handle)
        log.warning("FSDP firing post-backward hooks for parameters %s", param_fqns)


@no_type_check
def _get_handle_fqns_from_root(
    state: _FSDPState, handle: "FlatParamHandle"
) -> Optional[List[str]]:
    if handle is None:
        return None
    param_to_fqn = state._exec_order_data.param_to_fqn
    handle_params = handle.flat_param._params  # only populated for use_orig_params
    param_fqns = [
        fqn for fqn_list in [param_to_fqn[p] for p in handle_params] for fqn in fqn_list
    ]
    return param_fqns


def _apply_to_modules(
    root_module: torch.nn.Module,
    module_fn: Callable,


@@ -19,6 +19,7 @@ from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _get_module_fsdp_state,
    _is_composable,
    _log_post_backward_hook,
    _no_dispatch_record_stream,
    clean_tensor_name,
    TrainingState,
@@ -734,7 +735,7 @@ def _post_backward_hook(
    - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
    gradient (accumulating with any existing gradient).
    """
    _log_post_backward_hook(state, handle)
    _log_post_backward_hook(state, handle, log)
    flat_param = handle.flat_param
    flat_param._post_backward_called = True
    with torch.autograd.profiler.record_function(
@@ -789,22 +790,6 @@
    )


@no_type_check
def _log_post_backward_hook(state: _FSDPState, handle: FlatParamHandle) -> None:
    # Under TORCH_DISTRIBUTED_DEBUG=INFO, log the module names this hook fires for.
    # Below logging of module names this post-bwd hook fires for can help debug certain
    # cases where hooks don't fire, such as under certain activation checkpoint configs.
    if state._use_orig_params and handle._debug_level == dist.DebugLevel.INFO:
        param_to_fqn = state._exec_order_data.param_to_fqn
        handle_params = handle.flat_param._params  # only populated for use_orig_params
        param_fqns = [
            param
            for param_list in [param_to_fqn[p] for p in handle_params]
            for param in param_list
        ]
        log.warning("FSDP firing post-backward hooks for parameters %s", param_fqns)


def _post_backward_reshard(
    state: _FSDPState,
    handle: FlatParamHandle,