[Reland][FSDP] Implement apply() (#72925)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72925

Reland with a fix that adds the owner string to the test file.
ghstack-source-id: 149280348

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D34273858

fbshipit-source-id: 2174c1d71fcc5148282d94e375071a50b92114f2
Author: Rohan Varma, 2022-02-17 13:28:44 -08:00 (committed by Facebook GitHub Bot)
parent 57e9b034aa
commit 158762bbb3
2 changed files with 143 additions and 0 deletions

File 1 of 2 (new test file):

@@ -0,0 +1,104 @@
# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import (
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import (
    FSDPTest,
    NestedWrappedModule,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestApply(FSDPTest):
    @property
    def world_size(self):
        return 2

    @torch.no_grad()
    def _init_linear_weights(self, m):
        if type(m) == nn.Linear:
            m.weight.fill_(1.0)
            m.bias.fill_(1.0)

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    def check_weights(self, fsdp, expected_tensor_fn, check):
        with fsdp._summon_full_params(recurse=True):
            linear_modules = [
                module for module in fsdp.modules() if type(module) == nn.Linear
            ]
            for module in linear_modules:
                for param in module.parameters():
                    expected = expected_tensor_fn(param)
                    check(param, expected, f"Got {param} but expected {expected}")

    def _check_apply(self, fsdp):
        # Assert linear weights are not all 1.0
        self.check_weights(
            fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertNotEqual
        )

        fsdp.apply(self._init_linear_weights)

        # Ensure all weights are 1.0
        self.check_weights(
            fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertEqual
        )

    @skip_if_lt_x_gpu(2)
    def test_nested_module_apply(self):
        """
        Checks that apply() modifies weights appropriately on a nested FSDP instance.
        """
        nested_module = NestedWrappedModule(
            self.process_group, wrap_fsdp=True, wrap_everything=True
        )
        fsdp_module = FSDP(nested_module, self.process_group).cuda(self.rank)
        self._check_apply(fsdp_module)

    @skip_if_lt_x_gpu(2)
    def test_transformer_module_apply(self):
        """
        Checks that apply() modifies weights appropriately on a wrapped Transformer
        module.
        """
        transformer = self._get_wrapped_model(group=self.process_group).cuda(self.rank)
        self._check_apply(transformer)

    @skip_if_lt_x_gpu(2)
    def test_apply_in_summon_raises_error(self):
        """
        Ensures that if a user calls apply() on an FSDP instance within the full
        param summon context, an appropriate error is raised.
        """
        transformer = self._get_wrapped_model(group=self.process_group).cuda(self.rank)
        with transformer._summon_full_params(recurse=True):
            with self.assertRaisesRegex(ValueError, "expected to be in states"):
                transformer.apply(self._init_linear_weights)


if __name__ == "__main__":
    run_tests()

File 2 of 2:

@@ -290,6 +290,45 @@ class FullyShardedDataParallel(nn.Module):
        assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper)
        return self._fsdp_wrapped_module

    def fsdp_modules(self) -> List["FullyShardedDataParallel"]:
        """
        Helper function to return all nested FSDP instances, including self.
        """
        fsdp_modules = []
        for module in self.modules():
            if isinstance(module, FullyShardedDataParallel):
                fsdp_modules.append(module)
        return fsdp_modules
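
A minimal usage sketch for ``fsdp_modules()``, assuming torch.distributed is
already initialized with a default process group and one GPU per rank (the tiny
nested model below is a hypothetical placeholder, not taken from this change):

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Hypothetical nested wrapping: the inner Linear gets its own FSDP
    # instance, and the whole Sequential is wrapped again at the top level.
    inner = FSDP(nn.Linear(8, 8).cuda())
    model = FSDP(nn.Sequential(inner, nn.Linear(8, 8).cuda()))

    # fsdp_modules() returns every nested FSDP instance, including the root,
    # so here it yields the outer wrapper and ``inner``.
    assert len(model.fsdp_modules()) == 2
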
    def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
        r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
        as well as self. Typical use includes initializing the parameters of a model
        (see also :ref:`nn-init-doc`).

        Compared to ``torch.nn.Module.apply``, this version additionally gathers
        the full parameters before applying ``fn``. It should not be called from
        within another ``summon_full_params`` context.

        Args:
            fn (:class:`Module` -> None): function to be applied to each submodule

        Returns:
            Module: self
        """
        uninitialized = self._is_root is None
        self._assert_state(TrainingState_.IDLE)
        with self._summon_full_params(recurse=False):
            ret = super().apply(fn)

        # Reset lazy init that might be called by summon_full_params, since
        # it could have set is_root incorrectly for non-root FSDP instances.
        if uninitialized and self._is_root:
            for module in self.fsdp_modules():
                module._reset_lazy_init()

        return ret
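
The intended usage mirrors ``torch.nn.Module.apply``. A minimal sketch, assuming
torch.distributed is already initialized and one GPU per rank (``init_weights``
and the wrapped ``nn.Linear`` are hypothetical placeholders):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    @torch.no_grad()
    def init_weights(module: nn.Module) -> None:
        # Sees the unsharded parameters, because FSDP.apply() gathers the
        # full parameters before delegating to nn.Module.apply().
        if isinstance(module, nn.Linear):
            module.weight.fill_(1.0)
            module.bias.fill_(1.0)

    sharded = FSDP(nn.Linear(8, 8).cuda())
    sharded.apply(init_weights)  # returns ``sharded`` itself

As the new test_apply_in_summon_raises_error test exercises, calling apply()
while already inside _summon_full_params() raises a ValueError instead.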

    # Set two factors, 'self.gradient_predivide_factor' and
    # 'self.gradient_postdivide_factor', to avoid underflow and overflow
def _get_gradient_predivide_factor(self, world_size: int) -> float: