Revert D34111109: [FSDP] Implement apply()
Test Plan: revert-hammer

Differential Revision: D34111109 (1f29b3130a)

Original commit changeset: 60d9d3f5c4d6

Original Phabricator Diff: D34111109 (1f29b3130a)

fbshipit-source-id: d959533f656a1fa69b2af7c029130f674fdd6023
(cherry picked from commit b0d3e2b1c3)
This commit is contained in:
parent 17b3ba148d
commit ccdff4c480
@@ -1,108 +0,0 @@
import sys

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_fsdp import (
    FSDPTest,
    NestedWrappedModule,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)
from torch.testing._internal.common_distributed import (
    skip_if_lt_x_gpu,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestApply(FSDPTest):
    @property
    def world_size(self):
        return 2

    @torch.no_grad()
    def _init_linear_weights(self, m):
        if type(m) == nn.Linear:
            m.weight.fill_(1.0)
            m.bias.fill_(1.0)

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    def check_weights(self, fsdp, expected_tensor_fn, check):
        with fsdp._summon_full_params(recurse=True):
            linear_modules = [
                module for module in fsdp.modules() if type(module) == nn.Linear
            ]
            for module in linear_modules:
                for param in module.parameters():
                    expected = expected_tensor_fn(param)
                    check(param, expected)

    def _check_apply(self, fsdp):
        # Assert linear weights are not all 1.0
        self.check_weights(
            fsdp, lambda param: torch.ones_like(param), self.assertNotEqual
        )

        fsdp.apply(self._init_linear_weights)

        # Ensure all weights are 1.0
        self.check_weights(fsdp, lambda param: torch.ones_like(param), self.assertEqual)

    @skip_if_lt_x_gpu(2)
    def test_nested_module_apply(self):
        """
        Checks apply() modifies weights appropriately on a nested FSDP instance.
        """
        nested_module = NestedWrappedModule(
            self.process_group, wrap_fsdp=True, wrap_everything=True
        )
        fsdp_module = FSDP(nested_module, self.process_group).cuda(self.rank)
        self._check_apply(fsdp_module)

    @skip_if_lt_x_gpu(2)
    def test_transformer_module_apply(self):
        """
        Checks apply() modifies weights appropriately on a wrapped Transformer
        module.
        """
        transformer = self._get_wrapped_model(group=self.process_group).cuda(self.rank)
        # Assert linear weights are not all 1.0
        self.check_weights(
            transformer, lambda param: torch.ones_like(param), self.assertNotEqual
        )
        transformer.apply(self._init_linear_weights)
        # Assert all weights are 1.0
        self.check_weights(
            transformer, lambda param: torch.ones_like(param), self.assertEqual
        )

    @skip_if_lt_x_gpu(2)
    def test_apply_in_summon_raises_error(self):
        """
        Ensures that if user calls apply() on FSDP instance within full param
        summon context, appropriate error is raised.
        """
        transformer = self._get_wrapped_model(group=self.process_group).cuda(self.rank)
        with transformer._summon_full_params(recurse=True):
            with self.assertRaisesRegex(ValueError, "expected to be in states"):
                transformer.apply(self._init_linear_weights)


if __name__ == "__main__":
    run_tests()
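
Usage note (not part of the diff): the weight-initialization pattern the removed tests exercise can be checked in a plain single-process setting, without FSDP or a process group. A minimal sketch follows; the model and layer sizes are arbitrary placeholders, not taken from this commit.

import torch
import torch.nn as nn


@torch.no_grad()
def init_linear_weights(m):
    # Same idea as _init_linear_weights above: fill every nn.Linear with 1.0.
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        m.bias.fill_(1.0)


# Placeholder model; nn.Module.apply() visits every submodule recursively.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
model.apply(init_linear_weights)

# Mirror of check_weights: every Linear parameter should now be all ones.
for module in model.modules():
    if type(module) == nn.Linear:
        for param in module.parameters():
            assert torch.equal(param, torch.ones_like(param))
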
@@ -290,45 +290,6 @@ class FullyShardedDataParallel(nn.Module):
        assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper)
        return self._fsdp_wrapped_module

    def fsdp_modules(self) -> List["FullyShardedDataParallel"]:
        """
        Helper function to return all nested FSDP instances, including self.
        """
        fsdp_modules = []
        for module in self.modules():
            if isinstance(module, FullyShardedDataParallel):
                fsdp_modules.append(module)

        return fsdp_modules

    def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
        r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
        as well as self. Typical use includes initializing the parameters of a model
        (see also :ref:`nn-init-doc`).

        Compared to ``torch.nn.Module.apply``, this version additionally gathers
        the full parameters before applying ``fn``. It should not be called from
        within another ``summon_full_params`` context.

        Args:
            fn (:class:`Module` -> None): function to be applied to each submodule

        Returns:
            Module: self
        """
        uninitialized = self._is_root is None
        self._assert_state(TrainingState_.IDLE)
        with self._summon_full_params(recurse=False):
            ret = super().apply(fn)

        # Reset lazy init that might be called by summon_full_params, since
        # it could have set is_root incorrectly for non-root FSDP instances.
        if uninitialized and self._is_root:
            for module in self.fsdp_modules():
                module._reset_lazy_init()

        return ret

    # setting two factors 'self.gradient_predivide_factor'
    # and 'self.gradient_postdivide_factor' to avoid underflow and overflow
    def _get_gradient_predivide_factor(self, world_size: int) -> float:
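
Usage note (not part of the diff): the docstring above describes the FSDP.apply() override that this commit reverts — gather the full parameters, then run ``fn`` on every submodule. Below is a hedged sketch of how that reverted API was intended to be used; the script layout, model, and torchrun launch are illustrative assumptions, not taken from this commit.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


@torch.no_grad()
def init_weights(m):
    # Runs on each submodule while the full (unsharded) parameters are gathered.
    if isinstance(m, nn.Linear):
        m.weight.fill_(0.01)
        m.bias.zero_()


def main():
    # Assumes launch via e.g. `torchrun --nproc_per_node=2 <script>.py`,
    # which sets LOCAL_RANK and the rendezvous environment variables.
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)).cuda(local_rank)
    fsdp_model = FSDP(model)

    # With the reverted override, this gathers full parameters before applying
    # init_weights; it must not be called inside a summon_full_params context.
    fsdp_model.apply(init_weights)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
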