pytorch/test/distributed/checkpoint/test_save_load_api.py
Chien-Chin Huang d947534782 [DCP] Enable filesystem/fsspec auto detection (#118888)
This change lets the DCP save/load APIs automatically detect whether to use the filesystem or the fsspec storage backend based on the checkpoint_id.

Differential Revision: [D53318043](https://our.internmc.facebook.com/intern/diff/D53318043/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118888
Approved by: https://github.com/wz337, https://github.com/LucasLLC
2024-02-08 16:38:04 +00:00

66 lines
2.3 KiB
Python

# Owner(s): ["oncall: distributed"]
import os
from unittest.mock import patch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class MyTestModule(nn.Module):
    """Small feed-forward network used as the model under test.

    Takes inputs with 8 features and returns outputs with 8 features,
    passing through four stages (8 -> 16 -> 32 -> 64 -> 8).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is: the names show up
        # in the module's state_dict keys, which is what gets checkpointed.
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        """Apply net1, net2, net3 and net4 to ``x`` in that order."""
        out = x
        for stage in (self.net1, self.net2, self.net3, self.net4):
            out = stage(out)
        return out
class TestSaveAndLoadAPI(DTensorTestBase):
    """Exercises ``dcp.save`` / ``dcp.load`` when only ``checkpoint_id`` is given."""

    @property
    def world_size(self) -> int:
        """Number of ranks this test case runs with."""
        return 2

    @with_comms
    # The test uses ``world_size`` (2) ranks, so two GPUs are sufficient;
    # requiring more would skip it on hosts that can run it.
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_auto_detect(self):
        """Save and load an FSDP model, passing only a ``checkpoint_id``.

        Three scenarios are covered:

        * a path under ``self.temp_dir`` with the stock reader/writer;
        * another path under ``self.temp_dir`` while
          ``FileSystemReader.check`` and ``FileSystemWriter.check`` are
          patched to return ``False`` (presumably forcing a non-filesystem
          backend such as fsspec to be picked -- confirm against the
          detection logic);
        * the id ``"abc://abc.abc"``, for which both ``save`` and ``load``
          must raise ``RuntimeError`` matching ``"Cannot detect"``.
        """
        model = FSDP(MyTestModule().cuda())
        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
        model = FSDP(model, device_mesh=device_mesh)

        # Scenario 1: plain local path, nothing patched.
        first_ckpt = os.path.join(self.temp_dir, "first")
        dcp.save(model.state_dict(), checkpoint_id=first_ckpt)
        dcp.load(model.state_dict(), checkpoint_id=first_ckpt)

        # Scenario 2: the filesystem reader/writer both report that they
        # cannot handle the id, yet save and load must still succeed.
        second_ckpt = os.path.join(self.temp_dir, "second")
        with patch.object(dcp.FileSystemReader, "check", return_value=False):
            with patch.object(dcp.FileSystemWriter, "check", return_value=False):
                dcp.save(model.state_dict(), checkpoint_id=second_ckpt)
                dcp.load(model.state_dict(), checkpoint_id=second_ckpt)

        # Scenario 3: an id that must be rejected by both entry points.
        with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
            dcp.save(model.state_dict(), checkpoint_id="abc://abc.abc")

        with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
            dcp.load(model.state_dict(), checkpoint_id="abc://abc.abc")
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()