mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DCP][BE] Apply ufmt to DCP and turn on lintrunner for DCP (#115302)
No logic change. Just typing and ufmt. Differential Revision: [D51914982](https://our.internmc.facebook.com/intern/diff/D51914982/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/115302 Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/LucasLLC ghstack dependencies: #115523
This commit is contained in:
parent
cc28f61fa3
commit
db8d409d08
|
|
@ -869,7 +869,6 @@ exclude_patterns = [
|
|||
'test/bottleneck_test/**', # excluded by test/run_test.py
|
||||
'test/distributed/argparse_util_test.py',
|
||||
'test/distributed/bin/test_script.py',
|
||||
'test/distributed/checkpoint/e2e/test_pipeline.py',
|
||||
'test/distributed/elastic/agent/server/test/local_elastic_agent_test.py',
|
||||
'test/distributed/elastic/multiprocessing/bin/test_script.py',
|
||||
'test/distributed/elastic/multiprocessing/bin/zombie_test.py',
|
||||
|
|
@ -1094,19 +1093,6 @@ exclude_patterns = [
|
|||
'test/distributed/algorithms/test_join.py',
|
||||
'test/distributed/argparse_util_test.py',
|
||||
'test/distributed/bin/test_script.py',
|
||||
'test/distributed/checkpoint/test_2d_fsdp_dt_checkpoint.py',
|
||||
'test/distributed/checkpoint/test_checkpoint.py',
|
||||
'test/distributed/checkpoint/test_dedup_tensors.py',
|
||||
'test/distributed/checkpoint/test_dtensor_checkpoint.py',
|
||||
'test/distributed/checkpoint/test_file_system_checkpoint.py',
|
||||
'test/distributed/checkpoint/test_file_system_checkpoint_cpu.py',
|
||||
'test/distributed/checkpoint/test_fsdp_model_state.py',
|
||||
'test/distributed/checkpoint/test_fsdp_optim_state.py',
|
||||
'test/distributed/checkpoint/test_fsspec.py',
|
||||
'test/distributed/checkpoint/test_nested_dict.py',
|
||||
'test/distributed/checkpoint/test_planner.py',
|
||||
'test/distributed/checkpoint/test_traverse.py',
|
||||
'test/distributed/checkpoint/test_utils.py',
|
||||
'test/distributed/elastic/agent/server/test/__init__.py',
|
||||
'test/distributed/elastic/agent/server/test/api_test.py',
|
||||
'test/distributed/elastic/agent/server/test/local_elastic_agent_test.py',
|
||||
|
|
@ -2010,25 +1996,6 @@ exclude_patterns = [
|
|||
'torch/distributed/autograd/__init__.py',
|
||||
'torch/distributed/benchmarks/benchmark_ddp_rpc.py',
|
||||
'torch/distributed/c10d_logger.py',
|
||||
'torch/distributed/checkpoint/__init__.py',
|
||||
'torch/distributed/checkpoint/_dedup_tensors.py',
|
||||
'torch/distributed/checkpoint/_fsspec_filesystem.py',
|
||||
'torch/distributed/checkpoint/_nested_dict.py',
|
||||
'torch/distributed/checkpoint/_sharded_tensor_utils.py',
|
||||
'torch/distributed/checkpoint/_traverse.py',
|
||||
'torch/distributed/checkpoint/api.py',
|
||||
'torch/distributed/checkpoint/default_planner.py',
|
||||
'torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py',
|
||||
'torch/distributed/checkpoint/filesystem.py',
|
||||
'torch/distributed/checkpoint/metadata.py',
|
||||
'torch/distributed/checkpoint/optimizer.py',
|
||||
'torch/distributed/checkpoint/planner.py',
|
||||
'torch/distributed/checkpoint/planner_helpers.py',
|
||||
'torch/distributed/checkpoint/resharding.py',
|
||||
'torch/distributed/checkpoint/state_dict_loader.py',
|
||||
'torch/distributed/checkpoint/state_dict_saver.py',
|
||||
'torch/distributed/checkpoint/storage.py',
|
||||
'torch/distributed/checkpoint/utils.py',
|
||||
'torch/distributed/collective_utils.py',
|
||||
'torch/distributed/constants.py',
|
||||
'torch/distributed/distributed_c10d.py',
|
||||
|
|
@ -2442,7 +2409,6 @@ exclude_patterns = [
|
|||
'torch/testing/_internal/distributed/_shard/test_common.py',
|
||||
'torch/testing/_internal/distributed/_tensor/__init__.py',
|
||||
'torch/testing/_internal/distributed/_tensor/common_dtensor.py',
|
||||
'torch/testing/_internal/distributed/checkpoint_utils.py',
|
||||
'torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py',
|
||||
'torch/testing/_internal/distributed/distributed_test.py',
|
||||
'torch/testing/_internal/distributed/distributed_utils.py',
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_di
|
|||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTest
|
||||
from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
|
||||
|
||||
|
|
@ -102,3 +102,7 @@ class TestPipeline(FSDPTest):
|
|||
self.assertTrue(os.path.exists(pipeline_dir))
|
||||
self.save_with_pipeline(pipeline_dir)
|
||||
self.load_with_fsdp(pipeline_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,29 +1,28 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import sys
|
||||
from typing import Optional, List, cast
|
||||
from torch.distributed.checkpoint.storage import WriteResult
|
||||
|
||||
from torch.distributed.checkpoint import (
|
||||
StorageReader,
|
||||
StorageWriter,
|
||||
CheckpointException,
|
||||
load_state_dict,
|
||||
save_state_dict,
|
||||
)
|
||||
from typing import cast, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn
|
||||
import torch.futures
|
||||
from torch.futures import Future
|
||||
import torch.nn
|
||||
|
||||
from torch.distributed._shard import sharded_tensor
|
||||
|
||||
from torch.distributed.checkpoint.default_planner import (
|
||||
_create_default_local_metadata,
|
||||
from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
|
||||
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
|
||||
|
||||
from torch.distributed.checkpoint import (
|
||||
CheckpointException,
|
||||
load_state_dict,
|
||||
save_state_dict,
|
||||
StorageReader,
|
||||
StorageWriter,
|
||||
)
|
||||
|
||||
from torch.distributed.checkpoint.default_planner import _create_default_local_metadata
|
||||
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
BytesStorageMetadata,
|
||||
Metadata,
|
||||
|
|
@ -31,31 +30,21 @@ from torch.distributed.checkpoint.metadata import (
|
|||
)
|
||||
|
||||
from torch.distributed.checkpoint.planner import (
|
||||
SavePlan,
|
||||
SavePlanner,
|
||||
LoadPlan,
|
||||
LoadPlanner,
|
||||
SavePlan,
|
||||
SavePlanner,
|
||||
)
|
||||
from torch.distributed.checkpoint.storage import WriteResult
|
||||
from torch.futures import Future
|
||||
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
||||
|
||||
from torch.distributed._shard.sharded_tensor import (
|
||||
state_dict_hook,
|
||||
ShardedTensor,
|
||||
)
|
||||
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
|
||||
from torch.testing._internal.common_distributed import (
|
||||
requires_nccl,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
from torch.testing._internal.distributed._shard.sharded_tensor import (
|
||||
ShardedTensorTestBase,
|
||||
with_comms,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
run_tests,
|
||||
)
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
"Skip dev-asan as torch + multiprocessing spawn have known issues",
|
||||
|
|
@ -175,9 +164,7 @@ class TestStorageBase:
|
|||
ranks = self._get_ranks(name)
|
||||
fut = Future()
|
||||
if ranks is not None and self.rank in ranks:
|
||||
fut.set_exception(
|
||||
ValueError(f"async rank fail {self.rank} for {name}")
|
||||
)
|
||||
fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
|
||||
else:
|
||||
fut.set_result(result)
|
||||
return fut
|
||||
|
|
@ -204,9 +191,7 @@ class FaultyStorageWriter(TestStorageBase, StorageWriter):
|
|||
self._fail_rank("fail_write_data")
|
||||
return self._fail_rank_async("fail_write_data_async", [])
|
||||
|
||||
def finish(
|
||||
self, metadata: Metadata, results: List[List[WriteResult]]
|
||||
) -> None:
|
||||
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
|
||||
self._fail_rank("fail_finish")
|
||||
|
||||
|
||||
|
|
@ -239,9 +224,7 @@ class TestDistributedFailure(ShardedTensorTestBase):
|
|||
def get_spec(self):
|
||||
return ChunkShardingSpec(
|
||||
dim=0,
|
||||
placements=[
|
||||
f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())
|
||||
],
|
||||
placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
|
||||
)
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import dataclasses
|
||||
|
||||
import torch
|
||||
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
|
||||
from torch.distributed.checkpoint.planner import SavePlan, WriteItemType
|
||||
from torch.distributed.checkpoint.planner_helpers import (
|
||||
_create_write_item_for_tensor,
|
||||
)
|
||||
from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
|
|
@ -33,9 +32,7 @@ class TestDedupTensor(TestCase):
|
|||
self.assertEqual(2, len(dedup_plans[0].items))
|
||||
self.assertEqual(1, len(dedup_plans[1].items))
|
||||
|
||||
self.assertIn(
|
||||
"tensor_0", (item.index.fqn for item in dedup_plans[0].items)
|
||||
)
|
||||
self.assertIn("tensor_0", (item.index.fqn for item in dedup_plans[0].items))
|
||||
self.assertIn("r0", (item.index.fqn for item in dedup_plans[0].items))
|
||||
|
||||
self.assertIn("r1", (item.index.fqn for item in dedup_plans[1].items))
|
||||
|
|
|
|||
|
|
@ -6,19 +6,19 @@ import torch.distributed as dist
|
|||
import torch.distributed.checkpoint as dist_cp
|
||||
from torch.distributed._tensor import (
|
||||
DeviceMesh,
|
||||
distribute_tensor,
|
||||
DTensor,
|
||||
Replicate,
|
||||
Shard,
|
||||
distribute_tensor,
|
||||
zeros,
|
||||
)
|
||||
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
skip_if_lt_x_gpu,
|
||||
with_comms,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
|
||||
|
||||
SUBMESH_TENSOR_SIZE = 6
|
||||
|
|
@ -81,9 +81,7 @@ class DTensorPlanner(DTensorTestBase):
|
|||
device_type=self.device_type,
|
||||
mesh=range(dist.get_world_size()),
|
||||
)
|
||||
sharded_dt = distribute_tensor(
|
||||
tensor_to_shard, mesh, placements=[Shard(0)]
|
||||
)
|
||||
sharded_dt = distribute_tensor(tensor_to_shard, mesh, placements=[Shard(0)])
|
||||
replicated_dt = distribute_tensor(
|
||||
tensor_to_replicate, mesh, placements=[Replicate()]
|
||||
)
|
||||
|
|
@ -179,9 +177,7 @@ class DTensorPlanner(DTensorTestBase):
|
|||
storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
|
||||
planner=dist_cp.DefaultSavePlanner(),
|
||||
)
|
||||
model, _, _ = self.create_dtensor_model(
|
||||
local_tensor * 10, local_tensor_2 * 10
|
||||
)
|
||||
model, _, _ = self.create_dtensor_model(local_tensor * 10, local_tensor_2 * 10)
|
||||
state_dict = model.state_dict()
|
||||
|
||||
"""
|
||||
|
|
@ -247,9 +243,7 @@ class DTensorPlanner(DTensorTestBase):
|
|||
if k == "submesh_sdt":
|
||||
if self.rank % 2 == 0:
|
||||
shard_size = int(SUBMESH_TENSOR_SIZE / v.device_mesh.size())
|
||||
self.assertEqual(
|
||||
v.to_local().size(), torch.Size([shard_size])
|
||||
)
|
||||
self.assertEqual(v.to_local().size(), torch.Size([shard_size]))
|
||||
self.assertEqual(v.to_local(), torch.zeros([shard_size]))
|
||||
else:
|
||||
self.assertEqual(v.to_local().size(), torch.Size([0]))
|
||||
|
|
@ -258,9 +252,7 @@ class DTensorPlanner(DTensorTestBase):
|
|||
if k == "submesh_rdt":
|
||||
if self.rank % 2 == 0:
|
||||
shard_size = SUBMESH_TENSOR_SIZE
|
||||
self.assertEqual(
|
||||
v.to_local().size(), torch.Size([shard_size])
|
||||
)
|
||||
self.assertEqual(v.to_local().size(), torch.Size([shard_size]))
|
||||
self.assertEqual(v.to_local(), torch.zeros([shard_size]))
|
||||
else:
|
||||
self.assertEqual(v.to_local().size(), torch.Size([0]))
|
||||
|
|
|
|||
|
|
@ -1,42 +1,21 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed._shard import sharded_tensor
|
||||
from torch.distributed._shard.sharded_tensor import (
|
||||
ShardedTensor,
|
||||
state_dict_hook,
|
||||
)
|
||||
from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
|
||||
from torch.distributed._shard.sharding_spec import (
|
||||
ChunkShardingSpec,
|
||||
EnumerableShardingSpec,
|
||||
ShardingSpec,
|
||||
ShardMetadata,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import (
|
||||
requires_nccl,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.testing._internal.distributed._shard.sharded_tensor import (
|
||||
ShardedTensorTestBase,
|
||||
with_comms,
|
||||
)
|
||||
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
|
||||
MyShardedModel1,
|
||||
)
|
||||
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
run_tests,
|
||||
)
|
||||
|
||||
from torch.distributed.checkpoint import (
|
||||
FileSystemReader,
|
||||
|
|
@ -44,6 +23,20 @@ from torch.distributed.checkpoint import (
|
|||
load_state_dict,
|
||||
save_state_dict,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.distributed._shard.sharded_tensor import (
|
||||
ShardedTensorTestBase,
|
||||
with_comms,
|
||||
)
|
||||
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
|
||||
MyShardedModel1,
|
||||
)
|
||||
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
|
|
@ -122,9 +115,7 @@ class TestDistributedStateDictSaveLoad(TestCase):
|
|||
state_dict_to_load_to = MyTestModule().state_dict()
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
assert_state_dict_equal(
|
||||
self, state_dict_to_load_to, state_dict_to_save
|
||||
)
|
||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||
|
||||
# Load from file without any resharding
|
||||
fs_reader = FileSystemReader(path=path)
|
||||
|
|
@ -134,9 +125,7 @@ class TestDistributedStateDictSaveLoad(TestCase):
|
|||
no_dist=True,
|
||||
)
|
||||
|
||||
assert_state_dict_equal(
|
||||
self, state_dict_to_load_to, state_dict_to_save
|
||||
)
|
||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||
|
||||
with tempfile.TemporaryDirectory() as path:
|
||||
state_dict_to_save = MyTestModule().state_dict()
|
||||
|
|
@ -151,9 +140,7 @@ class TestDistributedStateDictSaveLoad(TestCase):
|
|||
state_dict_to_load_to = MyTestModule().state_dict()
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
assert_state_dict_equal(
|
||||
self, state_dict_to_load_to, state_dict_to_save
|
||||
)
|
||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||
|
||||
# Load from file without any resharding
|
||||
fs_reader = FileSystemReader(path=path)
|
||||
|
|
@ -163,9 +150,7 @@ class TestDistributedStateDictSaveLoad(TestCase):
|
|||
no_dist=True,
|
||||
)
|
||||
|
||||
assert_state_dict_equal(
|
||||
self, state_dict_to_load_to, state_dict_to_save
|
||||
)
|
||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||
|
||||
|
||||
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
||||
|
|
@ -212,15 +197,11 @@ class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
|
|||
dist.barrier()
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
assert_state_dict_equal(
|
||||
self, state_dict_to_load_to, state_dict_to_save
|
||||
)
|
||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||
|
||||
# Test load.
|
||||
fs_reader = FileSystemReader(path=path)
|
||||
load_state_dict(
|
||||
state_dict=state_dict_to_load_to, storage_reader=fs_reader
|
||||
)
|
||||
load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)
|
||||
|
||||
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
|
||||
dist.barrier()
|
||||
|
|
@ -238,9 +219,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
|||
|
||||
def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor:
|
||||
res = (
|
||||
torch.zeros(tensor.shape, device="cuda:0")
|
||||
if dist.get_rank() == 0
|
||||
else None
|
||||
torch.zeros(tensor.shape, device="cuda:0") if dist.get_rank() == 0 else None
|
||||
)
|
||||
tensor.gather(out=res)
|
||||
return res
|
||||
|
|
@ -335,9 +314,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
|||
state_dict_to_save = model_to_save.state_dict()
|
||||
|
||||
fs_writer = FileSystemWriter(path=path)
|
||||
save_state_dict(
|
||||
state_dict=state_dict_to_save, storage_writer=fs_writer
|
||||
)
|
||||
save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
|
|
@ -404,9 +381,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
|||
|
||||
fs_reader = FileSystemReader(path=path)
|
||||
|
||||
load_state_dict(
|
||||
state_dict=state_dict_to_load_to, storage_reader=fs_reader
|
||||
)
|
||||
load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)
|
||||
|
||||
# We can't use torch.allclose since each ST has a different sharding spec
|
||||
store_tensor = self.load_tensor(model_to_save.sharded_tensor)
|
||||
|
|
@ -516,9 +491,7 @@ class TestDistributedReshardOnLoad(ShardedTensorTestBase):
|
|||
f"save-spec {save_spec} load-spec {load_spec}",
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
save_dict["replicated"], load_dict_replicated
|
||||
),
|
||||
torch.allclose(save_dict["replicated"], load_dict_replicated),
|
||||
f"save-spec {save_spec} load-spec {load_spec}",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,23 +1,23 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint as dist_cp
|
||||
|
||||
from torch.distributed.checkpoint.default_planner import (
|
||||
DefaultLoadPlanner,
|
||||
DefaultSavePlanner,
|
||||
)
|
||||
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
|
||||
import torch.distributed.checkpoint as dist_cp
|
||||
import torch.distributed as dist
|
||||
|
||||
from torch.distributed.checkpoint.default_planner import (
|
||||
DefaultSavePlanner,
|
||||
DefaultLoadPlanner,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,23 +9,12 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import torch.nn as nn
|
||||
from torch.distributed.checkpoint._fsspec_filesystem import (
|
||||
FsspecReader,
|
||||
FsspecWriter,
|
||||
)
|
||||
from torch.distributed.checkpoint.optimizer import (
|
||||
load_sharded_optimizer_state_dict,
|
||||
)
|
||||
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
|
||||
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
|
||||
from torch.testing._internal.common_distributed import (
|
||||
requires_nccl,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.testing._internal.distributed._shard.sharded_tensor import (
|
||||
ShardedTensorTestBase,
|
||||
with_comms,
|
||||
|
|
@ -182,9 +171,7 @@ class TestFSSpecWithDist(ShardedTensorTestBase):
|
|||
return list(iter(opt.state.values()))[idx]
|
||||
|
||||
# Adam lazily creates its state
|
||||
self.assertEqual(
|
||||
opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"]
|
||||
)
|
||||
self.assertEqual(opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"])
|
||||
self.assertEqual(
|
||||
opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.distributed.checkpoint._nested_dict import (
|
||||
flatten_state_dict,
|
||||
unflatten_state_dict,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class TestFlattening(TestCase):
|
||||
|
|
|
|||
|
|
@ -3,43 +3,45 @@
|
|||
import sys
|
||||
|
||||
import torch
|
||||
from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
|
||||
|
||||
from torch.distributed._shard.sharded_tensor import (
|
||||
Shard,
|
||||
ShardMetadata,
|
||||
ShardedTensor,
|
||||
ShardedTensorMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
|
||||
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
run_tests,
|
||||
from torch.distributed.checkpoint.default_planner import (
|
||||
_create_default_local_metadata,
|
||||
create_default_global_save_plan,
|
||||
create_default_local_load_plan,
|
||||
create_default_local_save_plan,
|
||||
)
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
BytesStorageMetadata,
|
||||
ChunkStorageMetadata,
|
||||
MetadataIndex,
|
||||
TensorStorageMetadata,
|
||||
ChunkStorageMetadata,
|
||||
)
|
||||
from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
|
||||
|
||||
from torch.distributed.checkpoint.planner_helpers import (
|
||||
create_read_items_for_chunk_list,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
from torch.testing._internal.distributed.distributed_utils import (
|
||||
with_dist,
|
||||
with_fake_comms,
|
||||
with_dist
|
||||
)
|
||||
|
||||
from torch.distributed.checkpoint.default_planner import (
|
||||
create_default_global_save_plan,
|
||||
create_default_local_save_plan,
|
||||
create_default_local_load_plan,
|
||||
_create_default_local_metadata
|
||||
)
|
||||
|
||||
from torch.distributed.checkpoint.planner_helpers import create_read_items_for_chunk_list
|
||||
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
|
||||
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
|
|
@ -48,30 +50,34 @@ if TEST_WITH_DEV_DBG_ASAN:
|
|||
)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
|
||||
shards_metadata = []
|
||||
local_shards = []
|
||||
for idx in range(0, world_size * shards_per_rank):
|
||||
shard_rank = idx // shards_per_rank
|
||||
shard_md = ShardMetadata(shard_offsets=[idx * shard_size], shard_sizes=[shard_size], placement=f"rank:{shard_rank}/cpu")
|
||||
shard_md = ShardMetadata(
|
||||
shard_offsets=[idx * shard_size],
|
||||
shard_sizes=[shard_size],
|
||||
placement=f"rank:{shard_rank}/cpu",
|
||||
)
|
||||
shards_metadata.append(shard_md)
|
||||
if shard_rank == rank:
|
||||
shard = Shard.from_tensor_and_offsets(
|
||||
torch.rand(*shard_md.shard_sizes),
|
||||
shard_offsets=shard_md.shard_offsets,
|
||||
rank=rank
|
||||
rank=rank,
|
||||
)
|
||||
local_shards.append(shard)
|
||||
|
||||
sharded_tensor_md = ShardedTensorMetadata(
|
||||
shards_metadata=shards_metadata,
|
||||
size=torch.Size([shard_size * len(shards_metadata)]),
|
||||
tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1))
|
||||
tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
|
||||
)
|
||||
|
||||
return ShardedTensor._init_from_local_shards_and_global_metadata(
|
||||
local_shards=local_shards,
|
||||
sharded_tensor_metadata=sharded_tensor_md
|
||||
local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -81,18 +87,17 @@ class TestSavePlan(TestCase):
|
|||
tensor = torch.rand(10)
|
||||
val = [1, 2, 3]
|
||||
st = create_sharded_tensor(rank=1, world_size=4, shards_per_rank=1)
|
||||
state_dict = {
|
||||
"tensor": tensor,
|
||||
"value": val,
|
||||
"st": st
|
||||
}
|
||||
state_dict = {"tensor": tensor, "value": val, "st": st}
|
||||
plan = create_default_local_save_plan(state_dict, False)
|
||||
self.assertEqual(2, len(plan.items))
|
||||
wi = plan.items[0]
|
||||
self.assertEqual(wi.index, MetadataIndex("tensor", [0]))
|
||||
self.assertEqual(wi.type, WriteItemType.TENSOR)
|
||||
self.assertEqual(wi.tensor_data.size, tensor.size())
|
||||
self.assertEqual(wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
|
||||
self.assertEqual(
|
||||
wi.tensor_data.properties,
|
||||
TensorProperties.create_from_tensor(torch.zeros(1)),
|
||||
)
|
||||
self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([0]))
|
||||
self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([10]))
|
||||
|
||||
|
|
@ -100,7 +105,10 @@ class TestSavePlan(TestCase):
|
|||
self.assertEqual(st_wi.index, MetadataIndex("st", [8]))
|
||||
self.assertEqual(st_wi.type, WriteItemType.SHARD)
|
||||
self.assertEqual(st_wi.tensor_data.size, st.size())
|
||||
self.assertEqual(st_wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
|
||||
self.assertEqual(
|
||||
st_wi.tensor_data.properties,
|
||||
TensorProperties.create_from_tensor(torch.zeros(1)),
|
||||
)
|
||||
self.assertEqual(st_wi.tensor_data.chunk.offsets, torch.Size([8]))
|
||||
self.assertEqual(st_wi.tensor_data.chunk.sizes, torch.Size([8]))
|
||||
|
||||
|
|
@ -111,7 +119,10 @@ class TestSavePlan(TestCase):
|
|||
tensor_wi = next(wi for wi in plan.items if wi.type == WriteItemType.TENSOR)
|
||||
self.assertEqual(tensor_wi.index, MetadataIndex("tensor", [0]))
|
||||
self.assertEqual(tensor_wi.tensor_data.size, tensor.size())
|
||||
self.assertEqual(tensor_wi.tensor_data.properties, TensorProperties.create_from_tensor(tensor))
|
||||
self.assertEqual(
|
||||
tensor_wi.tensor_data.properties,
|
||||
TensorProperties.create_from_tensor(tensor),
|
||||
)
|
||||
self.assertEqual(tensor_wi.tensor_data.chunk.offsets, torch.Size([0]))
|
||||
self.assertEqual(tensor_wi.tensor_data.chunk.sizes, torch.Size([10]))
|
||||
|
||||
|
|
@ -125,11 +136,7 @@ class TestSavePlan(TestCase):
|
|||
tensor = torch.rand(10)
|
||||
val = [1, 2, 3]
|
||||
st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
|
||||
state_dict = {
|
||||
"tensor": tensor,
|
||||
"value": val,
|
||||
"st": st
|
||||
}
|
||||
state_dict = {"tensor": tensor, "value": val, "st": st}
|
||||
return create_default_local_save_plan(state_dict, rank == 0)
|
||||
|
||||
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
|
||||
|
|
@ -150,11 +157,15 @@ class TestSavePlan(TestCase):
|
|||
else:
|
||||
self.assertTrue(isinstance(item_md, TensorStorageMetadata))
|
||||
self.assertEqual(item_md.size, old_item.tensor_data.size)
|
||||
self.assertEqual(item_md.properties, old_item.tensor_data.properties)
|
||||
self.assertEqual(
|
||||
item_md.properties, old_item.tensor_data.properties
|
||||
)
|
||||
|
||||
self.assertIsNotNone(new_item.index.index)
|
||||
# Make sure the hint is correct
|
||||
self.assertEqual(item_md.chunks[new_item.index.index], old_item.tensor_data.chunk)
|
||||
self.assertEqual(
|
||||
item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
|
||||
)
|
||||
|
||||
def test_local_load_plan(self):
|
||||
def create_state_dict(rank):
|
||||
|
|
@ -162,11 +173,7 @@ class TestSavePlan(TestCase):
|
|||
tensor = torch.rand(10)
|
||||
val = [1, 2, 3]
|
||||
st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
|
||||
return {
|
||||
"tensor": tensor,
|
||||
"value": val,
|
||||
"st": st
|
||||
}
|
||||
return {"tensor": tensor, "value": val, "st": st}
|
||||
|
||||
state_dict = create_state_dict(1)
|
||||
metadata = _create_default_local_metadata(state_dict)
|
||||
|
|
@ -175,7 +182,9 @@ class TestSavePlan(TestCase):
|
|||
# This will create 3 entries
|
||||
self.assertEqual(3, len(load_plan.items))
|
||||
st_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "st")
|
||||
tensor_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "tensor")
|
||||
tensor_item = next(
|
||||
ri for ri in load_plan.items if ri.dest_index.fqn == "tensor"
|
||||
)
|
||||
bytes_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "value")
|
||||
|
||||
self.assertEqual(st_item.type, LoadItemType.TENSOR)
|
||||
|
|
@ -208,7 +217,6 @@ class TestSavePlan(TestCase):
|
|||
)
|
||||
}
|
||||
|
||||
|
||||
# Rank 1 has a 16 bytes shard from [16, 32[
|
||||
world8_state_dict = create_state_dict(rank=1, world_size=8)
|
||||
world8_metadata = _create_default_local_metadata(world8_state_dict)
|
||||
|
|
@ -221,8 +229,12 @@ class TestSavePlan(TestCase):
|
|||
# Each 4-world shard has 32 elements, so it needs to load 2 shards
|
||||
load_plan = create_default_local_load_plan(world4_state_dict, world8_metadata)
|
||||
self.assertEqual(2, len(load_plan.items))
|
||||
low_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0]))
|
||||
high_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16]))
|
||||
low_ri = next(
|
||||
ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
|
||||
)
|
||||
high_ri = next(
|
||||
ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16])
|
||||
)
|
||||
|
||||
self.assertEqual(low_ri.storage_index, MetadataIndex("st", [32]))
|
||||
self.assertEqual(low_ri.storage_offsets, torch.Size([0]))
|
||||
|
|
@ -258,6 +270,7 @@ class TestSavePlan(TestCase):
|
|||
shard_size=120 // world_size,
|
||||
)
|
||||
}
|
||||
|
||||
# rank 1 has a 30 bytes shard from [30, 60[
|
||||
world4_state_dict = create_state_dict(rank=1, world_size=4)
|
||||
world4_metadata = _create_default_local_metadata(world4_state_dict)
|
||||
|
|
@ -268,9 +281,13 @@ class TestSavePlan(TestCase):
|
|||
load_plan = create_default_local_load_plan(world3_state_dict, world4_metadata)
|
||||
self.assertEqual(2, len(load_plan.items))
|
||||
# this is [30, 60] to load [40, 60]
|
||||
low_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0]))
|
||||
low_ri = next(
|
||||
ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
|
||||
)
|
||||
# this is [60, 90] to load [60, 80]
|
||||
high_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20]))
|
||||
high_ri = next(
|
||||
ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20])
|
||||
)
|
||||
|
||||
self.assertEqual(low_ri.storage_index, MetadataIndex("st", [30]))
|
||||
self.assertEqual(low_ri.storage_offsets, torch.Size([10]))
|
||||
|
|
@ -284,27 +301,19 @@ class TestSavePlan(TestCase):
|
|||
self.assertEqual(high_ri.dest_offsets, torch.Size([20]))
|
||||
self.assertEqual(high_ri.lengths, torch.Size([20]))
|
||||
|
||||
|
||||
class TestPlannerHelpers(TestCase):
|
||||
def test_create_read_item_from_chunks(self):
|
||||
tensor_md = TensorStorageMetadata(
|
||||
properties=TensorProperties.create_from_tensor(torch.empty([16])),
|
||||
size=torch.Size([16]),
|
||||
chunks=[
|
||||
ChunkStorageMetadata(
|
||||
offsets=torch.Size([0]),
|
||||
sizes=torch.Size([8])
|
||||
),
|
||||
ChunkStorageMetadata(
|
||||
offsets=torch.Size([8]),
|
||||
sizes=torch.Size([8])
|
||||
)
|
||||
]
|
||||
ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8])),
|
||||
ChunkStorageMetadata(offsets=torch.Size([8]), sizes=torch.Size([8])),
|
||||
],
|
||||
)
|
||||
|
||||
chunk = ChunkStorageMetadata(
|
||||
offsets=torch.Size([4]),
|
||||
sizes=torch.Size([7])
|
||||
)
|
||||
chunk = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([7]))
|
||||
read_items = create_read_items_for_chunk_list("foo", tensor_md, [chunk])
|
||||
|
||||
self.assertEqual(2, len(read_items))
|
||||
|
|
@ -316,7 +325,6 @@ class TestPlannerHelpers(TestCase):
|
|||
|
||||
self.assertEqual(torch.Size([4]), read_items[0].lengths)
|
||||
|
||||
|
||||
self.assertEqual(MetadataIndex("foo", [4]), read_items[1].dest_index)
|
||||
self.assertEqual(torch.Size([4]), read_items[1].dest_offsets)
|
||||
|
||||
|
|
@ -325,5 +333,6 @@ class TestPlannerHelpers(TestCase):
|
|||
|
||||
self.assertEqual(torch.Size([3]), read_items[1].lengths)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
import torch.distributed.checkpoint._traverse as _traverse
|
||||
|
|
@ -95,9 +96,7 @@ class TestTraverse(TestCase):
|
|||
self.assertEqual(data[("key0", "key2")], torch.tensor([1]))
|
||||
|
||||
def test_traverse_doesnt_ignore_intermediate_collections(self) -> None:
|
||||
state_dict: STATE_DICT_TYPE = {
|
||||
"key0": [{"key1": {"key2": torch.tensor([1])}}]
|
||||
}
|
||||
state_dict: STATE_DICT_TYPE = {"key0": [{"key1": {"key2": torch.tensor([1])}}]}
|
||||
|
||||
data = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,22 +6,20 @@ import torch
|
|||
|
||||
from torch.distributed._shard.sharded_tensor import (
|
||||
Shard,
|
||||
ShardMetadata,
|
||||
ShardedTensor,
|
||||
ShardedTensorMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
|
||||
from torch.distributed.checkpoint.metadata import MetadataIndex
|
||||
from torch.distributed.checkpoint.utils import find_state_dict_object
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
TestCase,
|
||||
)
|
||||
from torch.distributed.checkpoint.utils import find_state_dict_object
|
||||
from torch.distributed.checkpoint.metadata import MetadataIndex
|
||||
from torch.testing._internal.distributed.distributed_utils import (
|
||||
with_fake_comms
|
||||
)
|
||||
from torch.testing._internal.distributed.distributed_utils import with_fake_comms
|
||||
|
||||
if TEST_WITH_DEV_DBG_ASAN:
|
||||
print(
|
||||
|
|
@ -30,30 +28,32 @@ if TEST_WITH_DEV_DBG_ASAN:
|
|||
)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def create_sharded_tensor(rank, world_size, shards_per_rank):
|
||||
shards_metadata = []
|
||||
local_shards = []
|
||||
for idx in range(0, world_size * shards_per_rank):
|
||||
shard_rank = idx // shards_per_rank
|
||||
shard_md = ShardMetadata(shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu")
|
||||
shard_md = ShardMetadata(
|
||||
shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu"
|
||||
)
|
||||
shards_metadata.append(shard_md)
|
||||
if shard_rank == rank:
|
||||
shard = Shard.from_tensor_and_offsets(
|
||||
torch.rand(*shard_md.shard_sizes),
|
||||
shard_offsets=shard_md.shard_offsets,
|
||||
rank=rank
|
||||
rank=rank,
|
||||
)
|
||||
local_shards.append(shard)
|
||||
|
||||
sharded_tensor_md = ShardedTensorMetadata(
|
||||
shards_metadata=shards_metadata,
|
||||
size=torch.Size([8 * len(shards_metadata)]),
|
||||
tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1))
|
||||
tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
|
||||
)
|
||||
|
||||
return ShardedTensor._init_from_local_shards_and_global_metadata(
|
||||
local_shards=local_shards,
|
||||
sharded_tensor_metadata=sharded_tensor_md
|
||||
local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,23 +1,15 @@
|
|||
from .api import CheckpointException
|
||||
from .checkpointer import Checkpointer
|
||||
from .default_planner import DefaultLoadPlanner, DefaultSavePlanner
|
||||
from .filesystem import FileSystemCheckpointer, FileSystemReader, FileSystemWriter
|
||||
from .metadata import (
|
||||
TensorStorageMetadata,
|
||||
BytesStorageMetadata,
|
||||
ChunkStorageMetadata,
|
||||
Metadata,
|
||||
TensorStorageMetadata,
|
||||
)
|
||||
from .state_dict_loader import load_state_dict, load
|
||||
from .state_dict_saver import save_state_dict, save
|
||||
from .storage import StorageReader, StorageWriter
|
||||
from .checkpointer import Checkpointer
|
||||
from .filesystem import FileSystemReader, FileSystemWriter, FileSystemCheckpointer
|
||||
from .api import CheckpointException
|
||||
|
||||
from .planner import (
|
||||
SavePlanner,
|
||||
LoadPlanner,
|
||||
SavePlan,
|
||||
LoadPlan,
|
||||
ReadItem,
|
||||
WriteItem,
|
||||
)
|
||||
from .default_planner import DefaultSavePlanner, DefaultLoadPlanner
|
||||
from .optimizer import load_sharded_optimizer_state_dict
|
||||
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
|
||||
from .state_dict_loader import load, load_state_dict
|
||||
from .state_dict_saver import save, save_state_dict
|
||||
from .storage import StorageReader, StorageWriter
|
||||
|
|
|
|||
|
|
@ -23,8 +23,10 @@ def init_logger() -> logging.Logger:
|
|||
logger.propagate = False
|
||||
return logger
|
||||
|
||||
|
||||
logger = init_logger()
|
||||
|
||||
|
||||
# TODO add docstring for dedup_tensors
|
||||
def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
|
||||
all_plans = list(all_plans)
|
||||
|
|
@ -51,8 +53,6 @@ def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
|
|||
for write_item in all_plans[plan_idx].items
|
||||
if write_item.index not in key_set
|
||||
]
|
||||
all_plans[plan_idx] = dataclasses.replace(
|
||||
all_plans[plan_idx], items=new_items
|
||||
)
|
||||
all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
|
||||
|
||||
return all_plans
|
||||
|
|
|
|||
|
|
@ -9,20 +9,18 @@ import pickle
|
|||
import queue
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, cast, Dict, List, Optional, Union
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
from fsspec import AbstractFileSystem
|
||||
from fsspec.core import url_to_fs
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._utils import _get_device_module
|
||||
|
||||
from torch.distributed._shard._utils import narrow_tensor_by_index
|
||||
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
|
||||
|
||||
from torch.distributed.checkpoint.planner import (
|
||||
LoadItemType,
|
||||
LoadPlan,
|
||||
|
|
@ -145,10 +143,7 @@ class _OverlappingCpuLoader(_TensorLoader):
|
|||
|
||||
def _refill(self):
|
||||
with self.device_module.stream(self.stream):
|
||||
while (
|
||||
not self._done
|
||||
and self.in_flight_data < self.inflight_threshhold
|
||||
):
|
||||
while not self._done and self.in_flight_data < self.inflight_threshhold:
|
||||
_, obj = self.items[self.idx]
|
||||
self.idx += 1
|
||||
tensor = self.resolve_fun(obj).detach()
|
||||
|
|
@ -206,9 +201,7 @@ def _item_size(item: WriteItem) -> int:
|
|||
return size * torch._utils._element_size(dtype)
|
||||
|
||||
|
||||
def _split_by_size_and_type(
|
||||
bins: int, items: List[WriteItem]
|
||||
) -> List[List[WriteItem]]:
|
||||
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
|
||||
if bins == 1:
|
||||
return [items]
|
||||
|
||||
|
|
@ -276,16 +269,12 @@ def _write_files_from_queue(
|
|||
planner.resolve_data,
|
||||
)
|
||||
|
||||
tensor_w = [
|
||||
wi for wi in write_items if wi.type != WriteItemType.BYTE_IO
|
||||
]
|
||||
tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
|
||||
for write_item in tensor_w:
|
||||
loader.add(_item_size(write_item), write_item)
|
||||
loader.start_loading()
|
||||
|
||||
bytes_w = [
|
||||
wi for wi in write_items if wi.type == WriteItemType.BYTE_IO
|
||||
]
|
||||
bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
|
||||
write_results = []
|
||||
|
||||
with fs.transaction:
|
||||
|
|
@ -351,9 +340,7 @@ class FsspecWriter(StorageWriter):
|
|||
self.fs.makedirs(self.path, exist_ok=True)
|
||||
return plan
|
||||
|
||||
def prepare_global_plan(
|
||||
self, global_plan: List[SavePlan]
|
||||
) -> List[SavePlan]:
|
||||
def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
|
||||
new_plans = [
|
||||
dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
|
||||
for i, plan in enumerate(global_plan)
|
||||
|
|
@ -376,9 +363,7 @@ class FsspecWriter(StorageWriter):
|
|||
|
||||
file_queue: queue.Queue = queue.Queue()
|
||||
if self.single_file_per_rank:
|
||||
for bucket in _split_by_size_and_type(
|
||||
self.thread_count, plan.items
|
||||
):
|
||||
for bucket in _split_by_size_and_type(self.thread_count, plan.items):
|
||||
file_name = gen_file()
|
||||
file_path = os.path.join(self.path, file_name)
|
||||
file_queue.put((file_path, file_name, bucket))
|
||||
|
|
@ -427,9 +412,7 @@ class FsspecWriter(StorageWriter):
|
|||
fut.set_result(res)
|
||||
return fut
|
||||
|
||||
def finish(
|
||||
self, metadata: Metadata, results: List[List[WriteResult]]
|
||||
) -> None:
|
||||
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
|
||||
storage_md = dict()
|
||||
for wr_list in results:
|
||||
storage_md.update({wr.index: wr.storage_data for wr in wr_list})
|
||||
|
|
@ -495,16 +478,12 @@ class FsspecReader(StorageReader):
|
|||
with fsspec.open(metadata_path, "rb") as metadata_file:
|
||||
return pickle.load(metadata_file)
|
||||
|
||||
def set_up_storage_reader(
|
||||
self, metadata: Metadata, is_coordinator: bool
|
||||
) -> None:
|
||||
def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
|
||||
self.storage_data = metadata.storage_data
|
||||
assert self.storage_data is not None
|
||||
|
||||
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
|
||||
return plan
|
||||
|
||||
def prepare_global_plan(
|
||||
self, global_plan: List[LoadPlan]
|
||||
) -> List[LoadPlan]:
|
||||
def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
|
||||
return global_plan
|
||||
|
|
|
|||
|
|
@ -1,16 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
STATE_DICT_TYPE,
|
||||
)
|
||||
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
|
||||
|
||||
from ._traverse import (
|
||||
traverse_state_dict,
|
||||
set_element,
|
||||
OBJ_PATH,
|
||||
STATE_DICT_ITEM,
|
||||
)
|
||||
from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
|
||||
|
||||
"""
|
||||
TODO:
|
||||
|
|
|
|||
|
|
@ -3,29 +3,12 @@
|
|||
import copy
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata
|
||||
from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
|
||||
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
|
||||
from torch.distributed.remote_device import _remote_device
|
||||
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
STATE_DICT_TYPE,
|
||||
)
|
||||
from torch.distributed._shard.sharded_tensor import (
|
||||
Shard,
|
||||
ShardMetadata,
|
||||
ShardedTensor,
|
||||
)
|
||||
|
||||
from torch.distributed._shard.sharded_tensor.metadata import (
|
||||
ShardedTensorMetadata,
|
||||
)
|
||||
|
||||
|
||||
from ._traverse import (
|
||||
OBJ_PATH,
|
||||
traverse_state_dict,
|
||||
set_element,
|
||||
STATE_DICT_ITEM,
|
||||
)
|
||||
|
||||
from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
|
||||
from .utils import _element_wise_add, _normalize_device_info
|
||||
|
||||
|
||||
|
|
@ -62,9 +45,7 @@ def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
|
|||
return
|
||||
|
||||
if len(inner_st.local_shards()) != 1:
|
||||
raise ValueError(
|
||||
"Cannot handle inner tensor with more than 1 shard"
|
||||
)
|
||||
raise ValueError("Cannot handle inner tensor with more than 1 shard")
|
||||
inner_shard = inner_st.local_shards()[0]
|
||||
|
||||
local_shards = [
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import torch
|
||||
|
||||
from typing import (
|
||||
Callable,
|
||||
cast,
|
||||
Collection,
|
||||
List,
|
||||
Mapping,
|
||||
|
|
@ -11,13 +10,12 @@ from typing import (
|
|||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
STATE_DICT_TYPE,
|
||||
)
|
||||
|
||||
import torch
|
||||
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
|
||||
from torch.distributed._tensor import DTensor
|
||||
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
|
||||
|
||||
PATH_ITEM = Union[str, int]
|
||||
OBJ_PATH = Tuple[PATH_ITEM, ...]
|
||||
|
|
@ -47,6 +45,7 @@ def traverse_state_dict(
|
|||
By default, all collections with at least one ``torch.Tensor`` element are traversed.
|
||||
Visitor takes a path argument that is a tuple of the keys used to reach it.
|
||||
"""
|
||||
|
||||
# a value is terminal if it has no other containers values inside it
|
||||
def _is_terminal(value: STATE_DICT_ITEM) -> bool:
|
||||
values: Collection[STATE_DICT_ITEM]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Dict, Tuple, Any
|
||||
import traceback as tb
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
|
||||
|
||||
|
|
@ -15,9 +15,7 @@ def _is_wrapped_exception(obj: Any) -> bool:
|
|||
return False
|
||||
if len(obj) != 2:
|
||||
return False
|
||||
return isinstance(obj[0], BaseException) and isinstance(
|
||||
obj[1], tb.StackSummary
|
||||
)
|
||||
return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary)
|
||||
|
||||
|
||||
class CheckpointException(BaseException):
|
||||
|
|
|
|||
|
|
@ -6,50 +6,42 @@ import logging
|
|||
import operator
|
||||
from collections import ChainMap
|
||||
from functools import reduce
|
||||
from typing import List, Tuple, Dict, Any, Union, cast
|
||||
from typing import Any, cast, Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from torch.distributed._shard._utils import narrow_tensor_by_index
|
||||
from torch.distributed._tensor import DTensor
|
||||
|
||||
|
||||
from torch.distributed.checkpoint.planner import (
|
||||
SavePlanner,
|
||||
LoadPlanner,
|
||||
SavePlan,
|
||||
LoadPlan,
|
||||
ReadItem,
|
||||
WriteItem,
|
||||
WriteItemType,
|
||||
)
|
||||
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
BytesStorageMetadata,
|
||||
ChunkStorageMetadata,
|
||||
TensorStorageMetadata,
|
||||
MetadataIndex,
|
||||
Metadata,
|
||||
STATE_DICT_TYPE,
|
||||
STORAGE_TYPES,
|
||||
)
|
||||
|
||||
from torch.distributed.checkpoint.planner_helpers import (
|
||||
_create_read_items,
|
||||
_create_write_items,
|
||||
_create_default_metadata_only_plan,
|
||||
)
|
||||
|
||||
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
|
||||
from torch.distributed.checkpoint._nested_dict import (
|
||||
FLATTEN_MAPPING,
|
||||
flatten_state_dict,
|
||||
)
|
||||
from torch.distributed.checkpoint._sharded_tensor_utils import (
|
||||
_flatten_sharded_tensors,
|
||||
)
|
||||
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
|
||||
from torch.distributed.checkpoint.utils import find_state_dict_object
|
||||
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
|
||||
from torch.distributed.checkpoint._traverse import set_element
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
BytesStorageMetadata,
|
||||
ChunkStorageMetadata,
|
||||
Metadata,
|
||||
MetadataIndex,
|
||||
STATE_DICT_TYPE,
|
||||
STORAGE_TYPES,
|
||||
TensorStorageMetadata,
|
||||
)
|
||||
from torch.distributed.checkpoint.planner import (
|
||||
LoadPlan,
|
||||
LoadPlanner,
|
||||
ReadItem,
|
||||
SavePlan,
|
||||
SavePlanner,
|
||||
WriteItem,
|
||||
WriteItemType,
|
||||
)
|
||||
from torch.distributed.checkpoint.planner_helpers import (
|
||||
_create_default_metadata_only_plan,
|
||||
_create_read_items,
|
||||
_create_write_items,
|
||||
)
|
||||
from torch.distributed.checkpoint.utils import find_state_dict_object
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -79,9 +71,7 @@ class DefaultSavePlanner(SavePlanner):
|
|||
self.dedup_replicated_tensors = dedup_replicated_tensors
|
||||
self.mappings = {}
|
||||
|
||||
def set_up_planner(
|
||||
self, state_dict: STATE_DICT_TYPE, is_coordinator: bool
|
||||
) -> None:
|
||||
def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
|
||||
if self.flatten_state_dict:
|
||||
state_dict, self.mappings = flatten_state_dict(state_dict)
|
||||
if self.flatten_sharded_tensors:
|
||||
|
|
@ -90,9 +80,7 @@ class DefaultSavePlanner(SavePlanner):
|
|||
self.is_coordinator = is_coordinator
|
||||
|
||||
def create_local_plan(self) -> SavePlan:
|
||||
plan = create_default_local_save_plan(
|
||||
self.state_dict, self.is_coordinator
|
||||
)
|
||||
plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
|
||||
if self.flatten_state_dict:
|
||||
plan = dataclasses.replace(plan, planner_data=self.mappings)
|
||||
self.plan = plan
|
||||
|
|
@ -114,9 +102,7 @@ class DefaultSavePlanner(SavePlanner):
|
|||
# )
|
||||
planner_data_dict = [p.planner_data for p in global_plan]
|
||||
merged_mappings = dict(ChainMap(*planner_data_dict))
|
||||
metadata = dataclasses.replace(
|
||||
metadata, planner_data=merged_mappings
|
||||
)
|
||||
metadata = dataclasses.replace(metadata, planner_data=merged_mappings)
|
||||
|
||||
if not _validate_global_plan(global_plan, metadata):
|
||||
raise ValueError("Failed to validate global plan")
|
||||
|
|
@ -130,9 +116,7 @@ class DefaultSavePlanner(SavePlanner):
|
|||
self.plan = new_plan
|
||||
return new_plan
|
||||
|
||||
def resolve_data(
|
||||
self, write_item: WriteItem
|
||||
) -> Union[torch.Tensor, io.BytesIO]:
|
||||
def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
|
||||
object = self.lookup_object(write_item.index)
|
||||
return self.transform_object(write_item, object)
|
||||
|
||||
|
|
@ -222,9 +206,7 @@ class DefaultLoadPlanner(LoadPlanner):
|
|||
|
||||
def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
|
||||
"""Extension from the planner interface to make it easy to extend the default planner."""
|
||||
return narrow_tensor_by_index(
|
||||
tensor, read_item.dest_offsets, read_item.lengths
|
||||
)
|
||||
return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
|
||||
|
||||
|
||||
def create_default_local_load_plan(
|
||||
|
|
@ -353,9 +335,7 @@ def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
|
|||
return md
|
||||
|
||||
|
||||
def _check_box_overlap(
|
||||
box0: ChunkStorageMetadata, box1: ChunkStorageMetadata
|
||||
) -> bool:
|
||||
def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
|
||||
"""Check if two boxes overlap. Tuples are (offset, lengths)."""
|
||||
# For each dim of each shard, check if one shard resides on the other
|
||||
# end of second shard with respect to that dim. As an example for a 2D
|
||||
|
|
@ -385,9 +365,7 @@ def _check_box_bounds(
|
|||
return True
|
||||
|
||||
|
||||
def _validate_global_plan(
|
||||
global_plan: List[SavePlan], metadata: Metadata
|
||||
) -> bool:
|
||||
def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
|
||||
all_good = True
|
||||
for key, value in metadata.state_dict_metadata.items():
|
||||
if isinstance(value, BytesStorageMetadata):
|
||||
|
|
@ -402,7 +380,10 @@ def _validate_global_plan(
|
|||
"""
|
||||
key:%s has out of bounds chunk:
|
||||
tensor-size:%s chunk: %s
|
||||
""", key, value.size, chunk0
|
||||
""",
|
||||
key,
|
||||
value.size,
|
||||
chunk0,
|
||||
)
|
||||
all_good = False
|
||||
chunks_volume += reduce(operator.mul, chunk0.sizes, 1)
|
||||
|
|
@ -422,7 +403,10 @@ def _validate_global_plan(
|
|||
"""
|
||||
key:%s invalid fill tensor-volume:
|
||||
%s chunks-volume: %s
|
||||
""", key, tensor_volume, chunks_volume
|
||||
""",
|
||||
key,
|
||||
tensor_volume,
|
||||
chunks_volume,
|
||||
)
|
||||
all_good = False
|
||||
|
||||
|
|
|
|||
|
|
@ -14,12 +14,10 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint as dist_cp
|
||||
import torch.multiprocessing as mp
|
||||
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
|
||||
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
|
||||
from torch.distributed.checkpoint.optimizer import (
|
||||
load_sharded_optimizer_state_dict,
|
||||
)
|
||||
|
||||
CHECKPOINT_DIR = f"/scratch/{os.environ['LOGNAME']}/checkpoint"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,53 +1,38 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import queue
|
||||
import threading
|
||||
import collections
|
||||
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import dataclasses
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
from typing import Optional, List, Union, Dict, cast
|
||||
import queue
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import cast, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch._utils import _get_device_module
|
||||
from torch.distributed._shard._utils import narrow_tensor_by_index
|
||||
from torch.distributed.checkpoint.checkpointer import Checkpointer
|
||||
from torch.futures import Future
|
||||
from pathlib import Path
|
||||
|
||||
from .metadata import (
|
||||
Metadata,
|
||||
MetadataIndex,
|
||||
)
|
||||
from .storage import (
|
||||
StorageReader,
|
||||
StorageWriter,
|
||||
WriteResult,
|
||||
)
|
||||
|
||||
from .metadata import Metadata, MetadataIndex
|
||||
from .planner import (
|
||||
LoadItemType,
|
||||
LoadPlanner,
|
||||
LoadPlan,
|
||||
LoadPlanner,
|
||||
ReadItem,
|
||||
SavePlan,
|
||||
SavePlanner,
|
||||
ReadItem,
|
||||
WriteItem,
|
||||
WriteItemType,
|
||||
)
|
||||
|
||||
from .storage import StorageReader, StorageWriter, WriteResult
|
||||
from .utils import _create_file_view
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.checkpoint.checkpointer import Checkpointer
|
||||
from torch.distributed._shard._utils import narrow_tensor_by_index
|
||||
from torch._utils import _get_device_module
|
||||
|
||||
__all__ = [
|
||||
"FileSystemWriter",
|
||||
"FileSystemReader",
|
||||
"FileSystemCheckpointer"
|
||||
]
|
||||
__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystemCheckpointer"]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -150,10 +135,7 @@ class _OverlappingCpuLoader(_TensorLoader):
|
|||
|
||||
def _refill(self):
|
||||
with self.device_module.stream(self.stream):
|
||||
while (
|
||||
not self._done
|
||||
and self.in_flight_data < self.inflight_threshhold
|
||||
):
|
||||
while not self._done and self.in_flight_data < self.inflight_threshhold:
|
||||
_, obj = self.items[self.idx]
|
||||
self.idx += 1
|
||||
tensor = self.resolve_fun(obj).detach()
|
||||
|
|
@ -211,9 +193,7 @@ def _item_size(item: WriteItem) -> int:
|
|||
return size * torch._utils._element_size(dtype)
|
||||
|
||||
|
||||
def _split_by_size_and_type(
|
||||
bins, items: List[WriteItem]
|
||||
) -> List[List[WriteItem]]:
|
||||
def _split_by_size_and_type(bins, items: List[WriteItem]) -> List[List[WriteItem]]:
|
||||
if bins == 1:
|
||||
return [items]
|
||||
|
||||
|
|
@ -276,16 +256,12 @@ def _write_files_from_queue(
|
|||
planner.resolve_data,
|
||||
)
|
||||
|
||||
tensor_w = [
|
||||
wi for wi in write_items if wi.type != WriteItemType.BYTE_IO
|
||||
]
|
||||
tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
|
||||
for write_item in tensor_w:
|
||||
loader.add(_item_size(write_item), write_item)
|
||||
loader.start_loading()
|
||||
|
||||
bytes_w = [
|
||||
wi for wi in write_items if wi.type == WriteItemType.BYTE_IO
|
||||
]
|
||||
bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
|
||||
write_results = []
|
||||
|
||||
with file_name.open("wb") as stream:
|
||||
|
|
@ -358,9 +334,7 @@ class FileSystemWriter(StorageWriter):
|
|||
self.path.mkdir(parents=True, exist_ok=True)
|
||||
return plan
|
||||
|
||||
def prepare_global_plan(
|
||||
self, global_plan: List[SavePlan]
|
||||
) -> List[SavePlan]:
|
||||
def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
|
||||
new_plans = [
|
||||
dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
|
||||
for i, plan in enumerate(global_plan)
|
||||
|
|
@ -383,9 +357,7 @@ class FileSystemWriter(StorageWriter):
|
|||
|
||||
file_queue: queue.Queue = queue.Queue()
|
||||
if self.single_file_per_rank:
|
||||
for bucket in _split_by_size_and_type(
|
||||
self.thread_count, plan.items
|
||||
):
|
||||
for bucket in _split_by_size_and_type(self.thread_count, plan.items):
|
||||
file_name = gen_file()
|
||||
file_queue.put((self.path / file_name, file_name, bucket))
|
||||
else:
|
||||
|
|
@ -432,9 +404,7 @@ class FileSystemWriter(StorageWriter):
|
|||
fut.set_result(res)
|
||||
return fut
|
||||
|
||||
def finish(
|
||||
self, metadata: Metadata, results: List[List[WriteResult]]
|
||||
) -> None:
|
||||
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
|
||||
storage_md = dict()
|
||||
for wr_list in results:
|
||||
storage_md.update({wr.index: wr.storage_data for wr in wr_list})
|
||||
|
|
@ -507,11 +477,10 @@ class FileSystemReader(StorageReader):
|
|||
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
|
||||
return plan
|
||||
|
||||
def prepare_global_plan(
|
||||
self, global_plan: List[LoadPlan]
|
||||
) -> List[LoadPlan]:
|
||||
def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
|
||||
return global_plan
|
||||
|
||||
|
||||
class FileSystemCheckpointer(Checkpointer):
|
||||
"""An implementation of :py:class:`torch.distributed.checkpoint.checkpointer.Checkpointer`
|
||||
for the file system. Wraps the creation and usage of ``FileSystemWriter`` and ``FileSystemReader``.
|
||||
|
|
@ -547,11 +516,7 @@ class FileSystemCheckpointer(Checkpointer):
|
|||
"""
|
||||
|
||||
storage_writer = FileSystemWriter(
|
||||
path,
|
||||
single_file_per_rank,
|
||||
sync_files,
|
||||
thread_count,
|
||||
per_thread_copy_ahead
|
||||
path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead
|
||||
)
|
||||
storage_reader = FileSystemReader(path)
|
||||
|
||||
|
|
@ -562,5 +527,5 @@ class FileSystemCheckpointer(Checkpointer):
|
|||
coordinator_rank=coordinator_rank,
|
||||
no_dist=no_dist,
|
||||
load_planner=load_planner,
|
||||
save_planner=save_planner
|
||||
save_planner=save_planner,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,10 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Union, Optional, Sequence, Any
|
||||
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
|
||||
from torch.distributed.checkpoint.stateful import StatefulT
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed._shard.sharded_tensor import (
|
||||
ShardedTensor,
|
||||
)
|
||||
from torch.distributed._shard.sharded_tensor import ShardedTensor
|
||||
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
|
||||
from torch.distributed.checkpoint.stateful import StatefulT
|
||||
|
||||
__all__ = [
|
||||
"ChunkStorageMetadata",
|
||||
|
|
|
|||
|
|
@ -1,49 +1,41 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
|
||||
import dataclasses
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union, cast
|
||||
from torch.distributed.checkpoint.planner import LoadPlan
|
||||
from typing import cast, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._utils import _get_device_module
|
||||
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
|
||||
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
|
||||
from torch.distributed._shard.sharded_tensor.shard import Shard
|
||||
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import (
|
||||
ChunkShardingSpec,
|
||||
)
|
||||
|
||||
import torch.distributed.checkpoint as dist_cp
|
||||
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
|
||||
from torch.distributed._tensor import DTensor
|
||||
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
|
||||
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
BytesStorageMetadata,
|
||||
ChunkStorageMetadata,
|
||||
Metadata,
|
||||
MetadataIndex,
|
||||
STATE_DICT_TYPE,
|
||||
TensorStorageMetadata,
|
||||
ChunkStorageMetadata,
|
||||
)
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
|
||||
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
|
||||
from torch.distributed.checkpoint.planner_helpers import (
|
||||
create_read_items_for_chunk_list,
|
||||
_create_read_items,
|
||||
create_read_items_for_chunk_list,
|
||||
)
|
||||
from torch.distributed.remote_device import _remote_device
|
||||
|
||||
from torch.distributed._tensor import DTensor
|
||||
from torch.distributed.checkpoint.default_planner import (
|
||||
DefaultLoadPlanner,
|
||||
)
|
||||
from torch.distributed.checkpoint.planner import LoadPlanner
|
||||
|
||||
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
|
||||
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
|
||||
from torch.distributed.checkpoint.storage import StorageReader
|
||||
from torch.distributed.checkpoint.utils import (
|
||||
_element_wise_add,
|
||||
_element_wise_sub,
|
||||
_normalize_device_info
|
||||
_normalize_device_info,
|
||||
)
|
||||
|
||||
from torch._utils import _get_device_module
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
|
||||
from torch.distributed.remote_device import _remote_device
|
||||
|
||||
STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]
|
||||
|
||||
|
|
@ -59,7 +51,9 @@ def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
|
|||
return "cpu"
|
||||
device_module = _get_device_module(device_type)
|
||||
if device_module.is_available():
|
||||
return _normalize_device_info(device_type, global_rank % device_module.device_count())
|
||||
return _normalize_device_info(
|
||||
device_type, global_rank % device_module.device_count()
|
||||
)
|
||||
return "cpu"
|
||||
|
||||
|
||||
|
|
@ -90,18 +84,17 @@ def _is_nested_tensor(val: torch.Tensor) -> bool:
|
|||
if type(val.local_shards()[0].tensor) is ShardedTensor:
|
||||
return True
|
||||
if type(val.local_shards()[0].tensor) is DTensor:
|
||||
raise ValueError(
|
||||
"Cannot handle DTensor nested insided ShardedTensor"
|
||||
)
|
||||
raise ValueError("Cannot handle DTensor nested insided ShardedTensor")
|
||||
elif type(val) is DTensor and (
|
||||
type(val._local_tensor) is DTensor
|
||||
or type(val._local_tensor) is ShardedTensor
|
||||
type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor
|
||||
):
|
||||
raise ValueError("Cannot handle nested DTensor")
|
||||
return False
|
||||
|
||||
|
||||
def _alloc_tensor(props: TensorProperties, size: Sequence[int], device_type: str = "cuda") -> torch.Tensor:
|
||||
def _alloc_tensor(
|
||||
props: TensorProperties, size: Sequence[int], device_type: str = "cuda"
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
size=size,
|
||||
dtype=props.dtype,
|
||||
|
|
@ -181,9 +174,7 @@ class _ReaderWithOffset(DefaultLoadPlanner):
|
|||
local_chunks = [
|
||||
ChunkStorageMetadata(
|
||||
offsets=torch.Size(
|
||||
_element_wise_add(
|
||||
original_shard.metadata.shard_offsets, offset
|
||||
)
|
||||
_element_wise_add(original_shard.metadata.shard_offsets, offset)
|
||||
),
|
||||
sizes=torch.Size(original_shard.metadata.shard_sizes),
|
||||
)
|
||||
|
|
@ -196,9 +187,7 @@ class _ReaderWithOffset(DefaultLoadPlanner):
|
|||
# TODO: we should change _create_sharded_read_items to have more ergonomic API
|
||||
for ri in reqs:
|
||||
assert ri.dest_index.offset is not None
|
||||
original_offset = _element_wise_sub(
|
||||
ri.dest_index.offset, offset
|
||||
)
|
||||
original_offset = _element_wise_sub(ri.dest_index.offset, offset)
|
||||
original_index = dataclasses.replace(
|
||||
ri.dest_index, offset=torch.Size(original_offset)
|
||||
)
|
||||
|
|
@ -214,7 +203,7 @@ class _ReaderWithOffset(DefaultLoadPlanner):
|
|||
def load_sharded_optimizer_state_dict(
|
||||
model_state_dict: STATE_DICT_TYPE,
|
||||
optimizer_key: str,
|
||||
storage_reader: dist_cp.StorageReader,
|
||||
storage_reader: StorageReader,
|
||||
planner: Optional[LoadPlanner] = None,
|
||||
) -> STATE_DICT_TYPE:
|
||||
"""
|
||||
|
|
@ -273,7 +262,9 @@ def load_sharded_optimizer_state_dict(
|
|||
if dp_pg is None:
|
||||
placements = []
|
||||
for i in range(dist.get_world_size()):
|
||||
device_info = _normalize_device_info(dp_pg_device_type, i % device_module.device_count())
|
||||
device_info = _normalize_device_info(
|
||||
dp_pg_device_type, i % device_module.device_count()
|
||||
)
|
||||
placements.append(f"rank:{i}/{device_info}")
|
||||
sharding_spec = ChunkShardingSpec(dim=0, placements=placements) # type: ignore[arg-type]
|
||||
else:
|
||||
|
|
@ -294,7 +285,9 @@ def load_sharded_optimizer_state_dict(
|
|||
|
||||
# value: TensorStorageMetadata
|
||||
if value.size.numel() == 1:
|
||||
state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
|
||||
state_dict[key] = _alloc_tensor(
|
||||
value.properties, value.size, dp_pg_device_type
|
||||
)
|
||||
elif dp_pg is None:
|
||||
state_dict[key] = _create_chunk_sharded_tensor(
|
||||
_alloc_tensor(value.properties, value.size, dp_pg_device_type),
|
||||
|
|
@ -313,10 +306,7 @@ def load_sharded_optimizer_state_dict(
|
|||
local_shards = []
|
||||
current_rank = dist.get_rank(dp_pg)
|
||||
for shard_md in st_md.shards_metadata:
|
||||
if (
|
||||
cast(_remote_device, shard_md.placement).rank()
|
||||
!= current_rank
|
||||
):
|
||||
if cast(_remote_device, shard_md.placement).rank() != current_rank:
|
||||
continue
|
||||
local_shards.append(
|
||||
Shard(
|
||||
|
|
@ -331,18 +321,13 @@ def load_sharded_optimizer_state_dict(
|
|||
local_shards, st_md, process_group=dp_pg
|
||||
)
|
||||
|
||||
if (
|
||||
spec_key in layout_specs
|
||||
and layout_specs[spec_key][0] is not None
|
||||
):
|
||||
fqn_to_offset[key] = cast(
|
||||
Sequence[int], layout_specs[spec_key][0]
|
||||
)
|
||||
if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
|
||||
fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])
|
||||
|
||||
state_dict[key] = st
|
||||
|
||||
# Whether we unflatten before or after doesn't matter
|
||||
dist_cp.load_state_dict(
|
||||
load_state_dict(
|
||||
state_dict=state_dict,
|
||||
storage_reader=storage_reader,
|
||||
# FIXME the type of planner is wrong in load_state_dict
|
||||
|
|
|
|||
|
|
@ -1,19 +1,13 @@
|
|||
import abc
|
||||
from dataclasses import dataclass
|
||||
import io
|
||||
from typing import List, Tuple, Any, Union, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
from enum import Enum, auto
|
||||
import torch
|
||||
|
||||
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
|
||||
|
||||
from .metadata import (
|
||||
ChunkStorageMetadata,
|
||||
MetadataIndex,
|
||||
Metadata,
|
||||
STATE_DICT_TYPE,
|
||||
)
|
||||
from .metadata import ChunkStorageMetadata, Metadata, MetadataIndex, STATE_DICT_TYPE
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -223,9 +217,7 @@ class SavePlanner(abc.ABC):
|
|||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def resolve_data(
|
||||
self, write_item: WriteItem
|
||||
) -> Union[torch.Tensor, io.BytesIO]:
|
||||
def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
|
||||
"""
|
||||
Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
|
||||
from torch.distributed._shard.metadata import ShardMetadata
|
||||
from torch.distributed._shard.sharded_tensor import ShardedTensor
|
||||
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
|
||||
|
|
@ -16,7 +15,6 @@ from .metadata import (
|
|||
STORAGE_TYPES,
|
||||
TensorStorageMetadata,
|
||||
)
|
||||
|
||||
from .planner import (
|
||||
LoadItemType,
|
||||
ReadItem,
|
||||
|
|
@ -25,7 +23,6 @@ from .planner import (
|
|||
WriteItem,
|
||||
WriteItemType,
|
||||
)
|
||||
|
||||
from .resharding import (
|
||||
_check_shard_metadata_pair_overlap,
|
||||
_shards_get_overlap_region_wrt_saved_tensor,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
ChunkStorageMetadata
|
||||
)
|
||||
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
def _check_shard_metadata_pair_overlap(shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata):
|
||||
|
||||
def _check_shard_metadata_pair_overlap(
|
||||
shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
|
||||
):
|
||||
"""Check if two shards overlap."""
|
||||
# For each dim of each shard, check if one shard resides on the other
|
||||
# end of second shard with respect to that dim. As an example for a 2D
|
||||
|
|
@ -21,6 +22,7 @@ def _check_shard_metadata_pair_overlap(shard1: ChunkStorageMetadata, shard2: Chu
|
|||
|
||||
return True
|
||||
|
||||
|
||||
def _shards_get_overlap_region_wrt_saved_tensor(
|
||||
saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
|
||||
) -> List[Tuple[int, int, int, int]]:
|
||||
|
|
@ -56,9 +58,7 @@ def _shards_get_overlap_region_wrt_saved_tensor(
|
|||
|
||||
if saved_shard_offset > current_shard_offset:
|
||||
offset_for_saved_tensor = 0
|
||||
offset_for_current_tensor = (
|
||||
saved_shard_offset - current_shard_offset
|
||||
)
|
||||
offset_for_current_tensor = saved_shard_offset - current_shard_offset
|
||||
else:
|
||||
offset_for_saved_tensor = current_shard_offset - saved_shard_offset
|
||||
offset_for_current_tensor = 0
|
||||
|
|
|
|||
|
|
@ -1,20 +1,18 @@
|
|||
from typing import Any, Dict, Optional
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.checkpoint.stateful import Stateful
|
||||
|
||||
from .storage import (
|
||||
StorageReader,
|
||||
)
|
||||
from .planner import LoadPlanner
|
||||
from .default_planner import DefaultLoadPlanner
|
||||
|
||||
from .utils import _DistWrapper, _all_gather_keys
|
||||
from .planner import LoadPlanner
|
||||
from .storage import StorageReader
|
||||
from .utils import _all_gather_keys, _DistWrapper
|
||||
|
||||
__all__ = ["load_state_dict", "load"]
|
||||
|
||||
|
||||
def load_state_dict(
|
||||
state_dict: Dict[str, Any],
|
||||
storage_reader: StorageReader,
|
||||
|
|
@ -28,7 +26,10 @@ def load_state_dict(
|
|||
"'load_state_dict' is deprecated and will be removed in future versions. Please use 'load' instead."
|
||||
)
|
||||
# TODO: test returning `load` here instead.
|
||||
return _load_state_dict(state_dict, storage_reader, process_group, coordinator_rank, no_dist, planner)
|
||||
return _load_state_dict(
|
||||
state_dict, storage_reader, process_group, coordinator_rank, no_dist, planner
|
||||
)
|
||||
|
||||
|
||||
def load(
|
||||
state_dict: Dict[str, Any],
|
||||
|
|
@ -124,7 +125,9 @@ def load(
|
|||
elem = state_dict[key]
|
||||
statetful_sd[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
|
||||
|
||||
_load_state_dict(statetful_sd, storage_reader, process_group, coordinator_rank, no_dist, planner)
|
||||
_load_state_dict(
|
||||
statetful_sd, storage_reader, process_group, coordinator_rank, no_dist, planner
|
||||
)
|
||||
for key in keys:
|
||||
if key not in state_dict:
|
||||
continue
|
||||
|
|
@ -133,6 +136,7 @@ def load(
|
|||
elem.load_state_dict(statetful_sd[key])
|
||||
state_dict[key] = elem
|
||||
|
||||
|
||||
def _load_state_dict(
|
||||
state_dict: Dict[str, Any],
|
||||
storage_reader: StorageReader,
|
||||
|
|
@ -141,7 +145,6 @@ def _load_state_dict(
|
|||
no_dist: bool = False,
|
||||
planner: Optional[LoadPlanner] = None,
|
||||
) -> None:
|
||||
|
||||
torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")
|
||||
|
||||
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,14 @@
|
|||
from typing import Optional
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.checkpoint.stateful import Stateful
|
||||
from .planner import SavePlanner
|
||||
|
||||
from .default_planner import DefaultSavePlanner
|
||||
|
||||
|
||||
from .storage import (
|
||||
StorageWriter,
|
||||
)
|
||||
|
||||
from .metadata import Metadata, STATE_DICT_TYPE
|
||||
from .planner import SavePlanner
|
||||
from .storage import StorageWriter
|
||||
from .utils import _DistWrapper
|
||||
|
||||
__all__ = ["save_state_dict", "save"]
|
||||
|
|
@ -32,7 +28,10 @@ def save_state_dict(
|
|||
)
|
||||
|
||||
# TODO: test returning `save` here instead.
|
||||
return _save_state_dict(state_dict, storage_writer, process_group, coordinator_rank, no_dist, planner)
|
||||
return _save_state_dict(
|
||||
state_dict, storage_writer, process_group, coordinator_rank, no_dist, planner
|
||||
)
|
||||
|
||||
|
||||
def save(
|
||||
state_dict: STATE_DICT_TYPE,
|
||||
|
|
@ -108,7 +107,9 @@ def save(
|
|||
|
||||
dumpable_state_dict = {}
|
||||
for key, elem in state_dict.items():
|
||||
dumpable_state_dict[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
|
||||
dumpable_state_dict[key] = (
|
||||
elem.state_dict() if isinstance(elem, Stateful) else elem
|
||||
)
|
||||
|
||||
return _save_state_dict(
|
||||
dumpable_state_dict,
|
||||
|
|
@ -116,9 +117,10 @@ def save(
|
|||
process_group,
|
||||
coordinator_rank,
|
||||
no_dist,
|
||||
planner
|
||||
planner,
|
||||
)
|
||||
|
||||
|
||||
def _save_state_dict(
|
||||
state_dict: STATE_DICT_TYPE,
|
||||
storage_writer: StorageWriter,
|
||||
|
|
@ -127,7 +129,6 @@ def _save_state_dict(
|
|||
no_dist: bool = False,
|
||||
planner: Optional[SavePlanner] = None,
|
||||
) -> Metadata:
|
||||
|
||||
torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")
|
||||
|
||||
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
|
||||
|
|
@ -149,9 +150,7 @@ def _save_state_dict(
|
|||
nonlocal global_metatadata
|
||||
|
||||
assert planner is not None
|
||||
all_local_plans, global_metatadata = planner.create_global_plan(
|
||||
all_local_plans
|
||||
)
|
||||
all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans)
|
||||
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
|
||||
return all_local_plans
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,11 @@
|
|||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Any
|
||||
from typing import Any, List
|
||||
|
||||
from torch.futures import Future
|
||||
|
||||
from .metadata import (
|
||||
Metadata,
|
||||
MetadataIndex,
|
||||
)
|
||||
|
||||
from .planner import (
|
||||
LoadPlan,
|
||||
SavePlan,
|
||||
SavePlanner,
|
||||
LoadPlanner,
|
||||
)
|
||||
from .metadata import Metadata, MetadataIndex
|
||||
from .planner import LoadPlan, LoadPlanner, SavePlan, SavePlanner
|
||||
|
||||
__all__ = ["WriteResult", "StorageWriter", "StorageReader"]
|
||||
|
||||
|
|
@ -115,9 +106,7 @@ class StorageWriter(abc.ABC):
|
|||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def finish(
|
||||
self, metadata: Metadata, results: List[List[WriteResult]]
|
||||
) -> None:
|
||||
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
|
||||
"""
|
||||
Write the metadata and marks the current checkpoint as successful.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,37 +1,21 @@
|
|||
import os
|
||||
import io
|
||||
import itertools
|
||||
from typing import (
|
||||
List,
|
||||
Callable,
|
||||
Optional,
|
||||
Union,
|
||||
TypeVar,
|
||||
Dict,
|
||||
Any,
|
||||
cast,
|
||||
Sequence
|
||||
)
|
||||
import torch.distributed as dist
|
||||
from .api import (
|
||||
CheckpointException,
|
||||
_wrap_exception,
|
||||
_is_wrapped_exception,
|
||||
WRAPPED_EXCEPTION,
|
||||
)
|
||||
import os
|
||||
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union
|
||||
|
||||
import torch
|
||||
|
||||
from torch.distributed._shard.sharded_tensor import (
|
||||
ShardedTensor,
|
||||
)
|
||||
import torch.distributed as dist
|
||||
from torch.distributed._shard.sharded_tensor import ShardedTensor
|
||||
from torch.distributed._shard.sharded_tensor.shard import Shard
|
||||
from torch.distributed._tensor import DTensor
|
||||
|
||||
from .metadata import (
|
||||
STATE_DICT_TYPE,
|
||||
MetadataIndex,
|
||||
from .api import (
|
||||
_is_wrapped_exception,
|
||||
_wrap_exception,
|
||||
CheckpointException,
|
||||
WRAPPED_EXCEPTION,
|
||||
)
|
||||
from .metadata import MetadataIndex, STATE_DICT_TYPE
|
||||
|
||||
__all__ = ["find_tensor_shard", "find_state_dict_object"]
|
||||
|
||||
|
|
@ -47,6 +31,7 @@ def _get_failure_dict(
|
|||
{i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
|
||||
)
|
||||
|
||||
|
||||
def _all_gather_keys(local_dict: Dict[Any, Any]) -> List[Any]:
|
||||
"""Gathers all keys, and returns them sorted."""
|
||||
keys = list(local_dict.keys())
|
||||
|
|
@ -55,6 +40,7 @@ def _all_gather_keys(local_dict: Dict[Any, Any]) -> List[Any]:
|
|||
dist.all_gather_object(gathered_keys, keys)
|
||||
return sorted(set(itertools.chain.from_iterable(gathered_keys)))
|
||||
|
||||
|
||||
class _DistWrapper:
|
||||
"""
|
||||
This is a wrapper around PG that provides a series of features around object collectives.
|
||||
|
|
@ -123,9 +109,7 @@ class _DistWrapper:
|
|||
def all_gather_object(self, object: T) -> List[T]:
|
||||
"""Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
|
||||
if self.use_dist:
|
||||
gather_objs = cast(
|
||||
List[T], [None] * dist.get_world_size(self.group)
|
||||
)
|
||||
gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))
|
||||
|
||||
dist.all_gather_object(
|
||||
object_list=gather_objs, obj=object, group=self.group
|
||||
|
|
@ -140,9 +124,7 @@ class _DistWrapper:
|
|||
gather_result = cast(List[T], [None])
|
||||
dist.scatter_object_list(
|
||||
scatter_object_output_list=gather_result,
|
||||
scatter_object_input_list=object_list
|
||||
if self.is_coordinator
|
||||
else None,
|
||||
scatter_object_input_list=object_list if self.is_coordinator else None,
|
||||
src=self.coordinator_rank,
|
||||
group=self.group,
|
||||
)
|
||||
|
|
@ -282,9 +264,7 @@ class _DistWrapper:
|
|||
try:
|
||||
result = map_fun()
|
||||
except BaseException as e:
|
||||
result = CheckpointException(
|
||||
step, {self.rank: _wrap_exception(e)}
|
||||
)
|
||||
result = CheckpointException(step, {self.rank: _wrap_exception(e)})
|
||||
final_result = self.broadcast_object(result)
|
||||
if isinstance(final_result, CheckpointException):
|
||||
raise final_result
|
||||
|
|
@ -302,22 +282,17 @@ def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
|
|||
if index.index is not None:
|
||||
if (
|
||||
len(shards) > index.index
|
||||
and torch.Size(shards[index.index].metadata.shard_offsets)
|
||||
== index.offset
|
||||
and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset
|
||||
):
|
||||
return shards[index.index]
|
||||
|
||||
for shard in shards:
|
||||
if torch.Size(shard.metadata.shard_offsets) == index.offset:
|
||||
return shard
|
||||
raise ValueError(
|
||||
f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
|
||||
)
|
||||
raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
|
||||
|
||||
|
||||
def find_tensor_shard(
|
||||
tensor: torch.Tensor, index: MetadataIndex
|
||||
) -> torch.Tensor:
|
||||
def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
|
||||
if isinstance(tensor, DTensor):
|
||||
return tensor.to_local()
|
||||
if isinstance(tensor, ShardedTensor):
|
||||
|
|
@ -332,9 +307,7 @@ def find_tensor_shard(
|
|||
return tensor
|
||||
|
||||
|
||||
def find_state_dict_object(
|
||||
state_dict: STATE_DICT_TYPE, index: MetadataIndex
|
||||
) -> Any:
|
||||
def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
|
||||
if index.fqn not in state_dict:
|
||||
raise ValueError(f"Could not find FQN: '{index.fqn}'")
|
||||
obj = state_dict[index.fqn]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user