pytorch/test/distributed/fsdp/test_fsdp_multiple_wrapping.py

# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module, Sequential
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class InnerModel(Module):
    def __init__(self) -> None:
        super().__init__()
        # The inner Linear layer is itself wrapped in FSDP, so this module
        # already contains a nested FSDP instance before any outer wrapping.
        self.layers = Sequential(FSDP(Linear(5, 5)))

    def forward(self, x):
        return self.layers(x)


class TestMultipleWrapping(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_multiple_wrapping(self):
        """
        This test simulates wrapping the module in FSDP again after training,
        in order to run inference. This covers the case where, later in a
        session, a model that already contains nested FSDP wrappers is
        wrapped in FSDP once more.
        """
        inner_model = InnerModel()
        model = FSDP(inner_model).cuda()
        optim = SGD(model.parameters(), lr=0.1)
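
        # Run a few training steps so the FSDP-managed parameters are
        # exercised before the inner model is re-wrapped below.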
        for _ in range(3):
            input = torch.rand((1, 5), dtype=torch.float).cuda()
            input.requires_grad = True
            output = model(input)
            output.sum().backward()
            optim.step()
            optim.zero_grad()
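
        # Reference output from the originally wrapped model.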
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        output = model(input)

        # Wrap the inner model in FSDP a second time and check that the
        # re-wrapped model produces the same output.
        rewrapped_model = FSDP(inner_model).cuda()
        rewrapped_output = rewrapped_model(input)

        self.assertEqual(output, rewrapped_output)


if __name__ == "__main__":
    run_tests()