Revert D34111109: [FSDP] Implement apply()

Test Plan: revert-hammer

Differential Revision: D34111109 (1f29b3130a)

Original commit changeset: 60d9d3f5c4d6

Original Phabricator Diff: D34111109 (1f29b3130a)

fbshipit-source-id: d959533f656a1fa69b2af7c029130f674fdd6023
(cherry picked from commit b0d3e2b1c3)
Nikita Shulga 2022-02-16 07:25:34 -08:00 committed by PyTorch MergeBot
parent 17b3ba148d
commit ccdff4c480
2 changed files with 0 additions and 147 deletions


@@ -1,108 +0,0 @@
import sys
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_fsdp import (
FSDPTest,
NestedWrappedModule,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
run_tests,
)
from torch.testing._internal.common_distributed import (
skip_if_lt_x_gpu,
)
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
class TestApply(FSDPTest):
@property
def world_size(self):
return 2
@torch.no_grad()
def _init_linear_weights(self, m):
if type(m) == nn.Linear:
m.weight.fill_(1.0)
m.bias.fill_(1.0)
@property
def process_group(self):
return dist.distributed_c10d._get_default_group()
def check_weights(self, fsdp, expected_tensor_fn, check):
with fsdp._summon_full_params(recurse=True):
linear_modules = [
module for module in fsdp.modules() if type(module) == nn.Linear
]
for module in linear_modules:
for param in module.parameters():
expected = expected_tensor_fn(param)
check(param, expected)
def _check_apply(self, fsdp):
# Assert linear weights are not all 1.0
self.check_weights(
fsdp, lambda param: torch.ones_like(param), self.assertNotEqual
)
fsdp.apply(self._init_linear_weights)
# Ensure all weights are 1.0
self.check_weights(fsdp, lambda param: torch.ones_like(param), self.assertEqual)
@skip_if_lt_x_gpu(2)
def test_nested_module_apply(self):
"""
Checks apply() modifies weights appropriately on a nested FSDP instance.
"""
nested_module = NestedWrappedModule(
self.process_group, wrap_fsdp=True, wrap_everything=True
)
fsdp_module = FSDP(nested_module, self.process_group).cuda(self.rank)
self._check_apply(fsdp_module)
@skip_if_lt_x_gpu(2)
def test_transformer_module_apply(self):
"""
Checks apply() modifies weights appropriately on a wrapped Transformer
module.
"""
transformer = self._get_wrapped_model(group=self.process_group).cuda(self.rank)
# Assert linear weights are not all 1.0
self.check_weights(
transformer, lambda param: torch.ones_like(param), self.assertNotEqual
)
transformer.apply(self._init_linear_weights)
# Assert all weights are 1.0
self.check_weights(
transformer, lambda param: torch.ones_like(param), self.assertEqual
)
@skip_if_lt_x_gpu(2)
def test_apply_in_summon_raises_error(self):
"""
Ensures that if the user calls apply() on an FSDP instance within a full
param summon context, an appropriate error is raised.
"""
transformer = self._get_wrapped_model(group=self.process_group).cuda(self.rank)
with transformer._summon_full_params(recurse=True):
with self.assertRaisesRegex(ValueError, "expected to be in states"):
transformer.apply(self._init_linear_weights)
if __name__ == "__main__":
run_tests()


@@ -290,45 +290,6 @@ class FullyShardedDataParallel(nn.Module):
assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper)
return self._fsdp_wrapped_module
def fsdp_modules(self) -> List["FullyShardedDataParallel"]:
"""
Helper function to return all nested FSDP instances, including self.
"""
fsdp_modules = []
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
fsdp_modules.append(module)
return fsdp_modules
def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
as well as self. Typical use includes initializing the parameters of a model
(see also :ref:`nn-init-doc`).
Compared to ``torch.nn.Module.apply``, this version additionally gathers
the full parameters before applying ``fn``. It should not be called from
within another ``summon_full_params`` context.
Args:
fn (:class:`Module` -> None): function to be applied to each submodule
Returns:
Module: self
"""
uninitialized = self._is_root is None
self._assert_state(TrainingState_.IDLE)
with self._summon_full_params(recurse=False):
ret = super().apply(fn)
# Reset lazy init that might be called by summon_full_params, since
# it could have set is_root incorrectly for non-root FSDP instances.
if uninitialized and self._is_root:
for module in self.fsdp_modules():
module._reset_lazy_init()
return ret
# Set the two factors 'self.gradient_predivide_factor'
# and 'self.gradient_postdivide_factor' to avoid underflow and overflow
def _get_gradient_predivide_factor(self, world_size: int) -> float:
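For reference, the reverted FullyShardedDataParallel.apply override gathered the full parameters before delegating to nn.Module.apply, as exercised by the removed test above. Below is a minimal sketch of that call pattern, assuming a process group has already been initialized and a GPU is available; the init_linear_weights helper is illustrative and mirrors the test's _init_linear_weights.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


@torch.no_grad()
def init_linear_weights(m: nn.Module) -> None:
    # Fill nn.Linear weights and biases with 1.0, as in the removed test helper.
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
        m.bias.fill_(1.0)


# Assumes torch.distributed.init_process_group(...) has already run on each rank.
model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))).cuda()
# Prior to this revert, apply() unsharded the parameters (_summon_full_params
# with recurse=False) before visiting each submodule, and raised a ValueError
# if called from inside another full-param summon context.
model.apply(init_linear_weights)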