diff --git a/test/distributed/fsdp/test_fsdp_apply.py b/test/distributed/fsdp/test_fsdp_apply.py
deleted file mode 100644
index d45fcada027..00000000000
--- a/test/distributed/fsdp/test_fsdp_apply.py
+++ /dev/null
@@ -1,108 +0,0 @@
-import sys
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.testing._internal.common_fsdp import (
-    FSDPTest,
-    NestedWrappedModule,
-)
-from torch.testing._internal.common_utils import (
-    TEST_WITH_DEV_DBG_ASAN,
-    run_tests,
-)
-from torch.testing._internal.common_distributed import (
-    skip_if_lt_x_gpu,
-)
-
-if not dist.is_available():
-    print("Distributed not available, skipping tests", file=sys.stderr)
-    sys.exit(0)
-
-if TEST_WITH_DEV_DBG_ASAN:
-    print(
-        "Skip dev-asan as torch + multiprocessing spawn have known issues",
-        file=sys.stderr,
-    )
-    sys.exit(0)
-
-
-class TestApply(FSDPTest):
-    @property
-    def world_size(self):
-        return 2
-
-    @torch.no_grad()
-    def _init_linear_weights(self, m):
-        if type(m) == nn.Linear:
-            m.weight.fill_(1.0)
-            m.bias.fill_(1.0)
-
-    @property
-    def process_group(self):
-        return dist.distributed_c10d._get_default_group()
-
-    def check_weights(self, fsdp, expected_tensor_fn, check):
-        with fsdp._summon_full_params(recurse=True):
-            linear_modules = [
-                module for module in fsdp.modules() if type(module) == nn.Linear
-            ]
-            for module in linear_modules:
-                for param in module.parameters():
-                    expected = expected_tensor_fn(param)
-                    check(param, expected)
-
-    def _check_apply(self, fsdp):
-        # Assert linear weights are not all 1.0
-        self.check_weights(
-            fsdp, lambda param: torch.ones_like(param), self.assertNotEqual
-        )
-
-        fsdp.apply(self._init_linear_weights)
-
-        # Ensure all weights are 1.0
-        self.check_weights(fsdp, lambda param: torch.ones_like(param), self.assertEqual)
-
-    @skip_if_lt_x_gpu(2)
-    def test_nested_module_apply(self):
-        """
-        Checks apply() modifies weights appropriately on a nested FSDP instance.
-        """
-        nested_module = NestedWrappedModule(
-            self.process_group, wrap_fsdp=True, wrap_everything=True
-        )
-        fsdp_module = FSDP(nested_module, self.process_group).cuda(self.rank)
-        self._check_apply(fsdp_module)
-
-    @skip_if_lt_x_gpu(2)
-    def test_transformer_module_apply(self):
-        """
-        Checks apply() modifiees weights appropriately on a wrapped Transformer
-        module.
-        """
-        transformer = self._get_wrapped_model(group=self.process_group).cuda(self.rank)
-        # Assert linear weights are not all 1.0
-        self.check_weights(
-            transformer, lambda param: torch.ones_like(param), self.assertNotEqual
-        )
-        transformer.apply(self._init_linear_weights)
-        # Assert all weights are 1.0
-        self.check_weights(
-            transformer, lambda param: torch.ones_like(param), self.assertEqual
-        )
-
-    @skip_if_lt_x_gpu(2)
-    def test_apply_in_summon_raises_error(self):
-        """
-        Ensures that if user calls apply() on FSDP instance within full param
-        summon context, appropriate error is raised.
-        """
-        transformer = self._get_wrapped_model(group=self.process_group).cuda(self.rank)
-        with transformer._summon_full_params(recurse=True):
-            with self.assertRaisesRegex(ValueError, "expected to be in states"):
-                transformer.apply(self._init_linear_weights)
-
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index fe61684b69d..d270230eba1 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -290,45 +290,6 @@ class FullyShardedDataParallel(nn.Module):
         assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper)
         return self._fsdp_wrapped_module
 
-    def fsdp_modules(self) -> List["FullyShardedDataParallel"]:
-        """
-        Helper function to return all nested FSDP instances, including self.
-        """
-        fsdp_modules = []
-        for module in self.modules():
-            if isinstance(module, FullyShardedDataParallel):
-                fsdp_modules.append(module)
-
-        return fsdp_modules
-
-    def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
-        r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
-        as well as self. Typical use includes initializing the parameters of a model
-        (see also :ref:`nn-init-doc`).
-
-        Compared to ``torch.nn.Module.apply``, this version additionally gathers
-        the full parameters before applying ``fn``. It should not be called from
-        within another ``summon_full_params`` context.
-
-        Args:
-            fn (:class:`Module` -> None): function to be applied to each submodule
-
-        Returns:
-            Module: self
-        """
-        uninitialized = self._is_root is None
-        self._assert_state(TrainingState_.IDLE)
-        with self._summon_full_params(recurse=False):
-            ret = super().apply(fn)
-
-        # Reset lazy init that might be called by summon_full_params, since
-        # it could have set is_root incorrectly for non-root FSDP instances.
-        if uninitialized and self._is_root:
-            for module in self.fsdp_modules():
-                module._reset_lazy_init()
-
-        return ret
-
     # setting two factors 'self.gradient_predivide_factor'
     # and 'self.gradient_postdivide_factor' to avoid underflow and overflow
     def _get_gradient_predivide_factor(self, world_size: int) -> float: