Enable custom device support in fsdp checkpoint (#107289)
Fixes https://github.com/pytorch/pytorch/issues/104390

Enables custom device (PrivateUse1 backend) support in FSDP checkpointing via a dynamically resolved, abstract device module.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107289
Approved by: https://github.com/wz337
parent b18e1b684a
commit ff37f6018d
@@ -1,4 +1,5 @@
 import copyreg
+import functools
 import sys
 import traceback
 import warnings
@@ -839,3 +840,13 @@ def classproperty(func):
 # Whether we are compiling with torch.compile or not
 def is_compiling():
     return False
+
+
+@functools.lru_cache(2)
+def _get_device_module(device_type: str):
+    device_module = getattr(torch, device_type, None)
+    if device_module is None:
+        raise RuntimeError(
+            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
+        )
+    return device_module
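For orientation, here is a small usage sketch (not part of the diff) of the new `_get_device_module` helper: it resolves the device module dynamically via `getattr(torch, device_type)`, so a renamed PrivateUse1 backend behaves the same as `torch.cuda`, provided a module is registered under that name.

```python
# Illustrative sketch only; assumes a stock PyTorch install.
import torch
from torch._utils import _get_device_module

cuda_mod = _get_device_module("cuda")   # resolves to torch.cuda
print(cuda_mod.is_available())

# For a PrivateUse1 backend registered under e.g. "foo", torch.foo must exist,
# otherwise the helper raises RuntimeError. "my_backend_module" below is a
# hypothetical out-of-tree module:
#   torch.utils.rename_privateuse1_backend("foo")
#   torch._register_device_module("foo", my_backend_module)
#   foo_mod = _get_device_module("foo")  # resolves to torch.foo
```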
@@ -18,6 +18,7 @@ import fsspec
 import torch
 from fsspec.core import url_to_fs
 from torch import Tensor
+from torch._utils import _get_device_module

 from torch.distributed._shard._utils import narrow_tensor_by_index
 from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
@@ -114,7 +115,7 @@ class _OverlappingCpuLoader(_TensorLoader):
     def __init__(
         self,
         resolve_fun: Callable,
-        stream: Union[None, io.RawIOBase, torch._C._CudaStreamBase] = None,
+        stream: Union[None, io.RawIOBase, torch.Stream] = None,
         inflight_threshhold: int = 1_000_000,
     ):
         self.resolve_fun = resolve_fun
@@ -124,9 +125,11 @@ class _OverlappingCpuLoader(_TensorLoader):
         self.current_items: collections.deque = collections.deque()
         self.idx = 0
         self.started = False
-        self.stream = stream or torch.cuda.current_stream()
-        if self.stream != torch.cuda.current_stream():
-            self.stream.wait_stream(torch.cuda.current_stream())
+        self.device_type = stream.device_type if stream else torch.device("cuda").type
+        self.device_module = _get_device_module(self.device_type)
+        self.stream = stream or self.device_module.current_stream()
+        if self.stream != self.device_module.current_stream():
+            self.stream.wait_stream(self.device_module.current_stream())

     @property
     def _done(self):
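The loader change above swaps hard-coded `torch.cuda` stream calls for the device module looked up from the stream's device type (defaulting to `cuda`). A minimal sketch of that pattern, assuming a machine with a CUDA device; any backend exposing the same stream API would work identically:

```python
import torch
from torch._utils import _get_device_module

device_type = torch.device("cuda").type
device_module = _get_device_module(device_type)

copy_stream = device_module.Stream()
if copy_stream != device_module.current_stream():
    # Make the side stream wait for work already queued on the default stream.
    copy_stream.wait_stream(device_module.current_stream())

with device_module.stream(copy_stream):
    t = torch.ones(4, device=device_type)
    host_t = t.to("cpu", non_blocking=True)  # D2H copy issued on the side stream
copy_stream.synchronize()
print(host_t)
```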
@@ -143,7 +146,7 @@ class _OverlappingCpuLoader(_TensorLoader):
         return drained

     def _refill(self):
-        with torch.cuda.stream(self.stream):
+        with self.device_module.stream(self.stream):
             while (
                 not self._done
                 and self.in_flight_data < self.inflight_threshhold
@@ -151,7 +154,7 @@ class _OverlappingCpuLoader(_TensorLoader):
                 _, obj = self.items[self.idx]
                 self.idx += 1
                 tensor = self.resolve_fun(obj).detach()
-                if tensor.is_cuda:
+                if tensor.device.type == self.device_type:
                     tensor = tensor.to(device="cpu", non_blocking=True)
                 elif tensor.device == torch.device("cpu"):
                     if tensor.storage().size() != tensor.numel():
@@ -232,7 +235,7 @@ def _split_by_size_and_type(


 def _write_item(
-    stream: Optional[Union[io.RawIOBase, torch._C._CudaStreamBase]],
+    stream: Optional[Union[io.RawIOBase, torch.Stream]],
     data: Union[io.BytesIO, torch.Tensor],
     write_item: WriteItem,
     storage_key: str,
@@ -294,7 +297,7 @@ def _write_files_from_queue(
         )

         for tensor, write_item in loader.values():
-            assert not tensor.is_cuda
+            assert tensor.is_cpu
             write_results.append(
                 _write_item(stream, tensor, write_item, storage_key)
             )
@@ -26,7 +26,7 @@ from ._traverse import (
     STATE_DICT_ITEM,
 )

-from .utils import _element_wise_add
+from .utils import _element_wise_add, _normalize_device_info


 # TODO: We need to refactor this code.
@@ -83,6 +83,7 @@ def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:

         st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata())
         other_rank = 0 if dist.get_rank() > 0 else 1
+        device_info = _normalize_device_info(inner_shard.tensor.device.type, 0)

         # Remove the outer ST shard the inner ST covers
         for i, shard_md in enumerate(st_meta.shards_metadata):
@@ -92,7 +93,7 @@ def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:

         # Attribute other rank for the other shards
         for shard_md in st_meta.shards_metadata:
-            shard_md.placement = _remote_device(f"rank:{other_rank}/cuda:0")
+            shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}")

         # Add other inner shards from the inner tensor
         for inner_md in inner_st.metadata().shards_metadata:
@@ -104,7 +105,7 @@ def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
                         inner_md.shard_offsets,
                     ),
                     shard_sizes=inner_md.shard_sizes,
-                    placement=f"rank:{other_rank}/cuda:0",
+                    placement=f"rank:{other_rank}/{device_info}",
                 )
             )

@@ -39,6 +39,7 @@ from .planner import (
 from .utils import _create_file_view

 from torch.distributed._shard._utils import narrow_tensor_by_index
+from torch._utils import _get_device_module

 __all__ = [
     "FileSystemWriter",
@@ -126,9 +127,11 @@ class _OverlappingCpuLoader(_TensorLoader):
         self.current_items: collections.deque = collections.deque()
         self.idx = 0
         self.started = False
-        self.stream = stream or torch.cuda.current_stream()
-        if self.stream != torch.cuda.current_stream():
-            self.stream.wait_stream(torch.cuda.current_stream())
+        self.device_type = stream.device_type if stream else torch.device("cuda").type
+        self.device_module = _get_device_module(self.device_type)
+        self.stream = stream or self.device_module.current_stream()
+        if self.stream != self.device_module.current_stream():
+            self.stream.wait_stream(self.device_module.current_stream())

     @property
     def _done(self):
@@ -145,7 +148,7 @@ class _OverlappingCpuLoader(_TensorLoader):
         return drained

     def _refill(self):
-        with torch.cuda.stream(self.stream):
+        with self.device_module.stream(self.stream):
             while (
                 not self._done
                 and self.in_flight_data < self.inflight_threshhold
@@ -153,7 +156,7 @@ class _OverlappingCpuLoader(_TensorLoader):
                 _, obj = self.items[self.idx]
                 self.idx += 1
                 tensor = self.resolve_fun(obj).detach()
-                if tensor.is_cuda:
+                if tensor.device.type == self.device_type:
                     tensor = tensor.to(device="cpu", non_blocking=True)
                 elif tensor.device == torch.device("cpu"):
                     if tensor.storage().size() != tensor.numel():
@@ -292,7 +295,7 @@ def _write_files_from_queue(
         )

         for tensor, write_item in loader.values():
-            assert not tensor.is_cuda
+            assert tensor.is_cpu
             write_results.append(
                 _write_item(stream, tensor, write_item, storage_key)
             )
@@ -38,8 +38,11 @@ from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
 from torch.distributed.checkpoint.utils import (
     _element_wise_add,
     _element_wise_sub,
+    _normalize_device_info
 )

+from torch._utils import _get_device_module
+
 STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]

@@ -49,23 +52,27 @@ __all__ = [
 ]


-def _gen_rank_device(global_rank: int) -> str:
-    if torch.cuda.is_available():
-        return f"cuda:{global_rank % torch.cuda.device_count()}"
+def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
+    if device_type == "cpu":
+        return "cpu"
+    device_module = _get_device_module(device_type)
+    if device_module.is_available():
+        return _normalize_device_info(device_type, global_rank % device_module.device_count())
     return "cpu"


 def _create_colwise_spec(
     pg: Optional[dist.ProcessGroup] = None,
 ) -> ChunkShardingSpec:
+    pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type
     if pg is None:
         placements = [
-            f"rank:{idx}/{_gen_rank_device(idx)}"
+            f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}"
             for idx in range(dist.get_world_size())
         ]
     else:
         placements = [
-            f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx))}"
+            f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}"
             for idx in range(pg.size())
         ]
     return ChunkShardingSpec(
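To make the placement logic above concrete, here is a standalone sketch (a hypothetical helper mirroring the combination of `_gen_rank_device` and `_normalize_device_info`, not a call into the real functions) of how placement strings now follow the process group's device type instead of assuming `cuda`:

```python
def gen_placement(rank: int, device_type: str, device_count: int) -> str:
    # CPU placements carry no index; every other backend gets "type:index".
    if device_type == "cpu" or device_count == 0:
        return f"rank:{rank}/cpu"
    return f"rank:{rank}/{device_type}:{rank % device_count}"

print(gen_placement(0, "cuda", 8))  # rank:0/cuda:0
print(gen_placement(1, "npu", 8))   # rank:1/npu:1  (custom backend example)
print(gen_placement(1, "cpu", 0))   # rank:1/cpu
```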
@@ -92,14 +99,14 @@ def _is_nested_tensor(val: torch.Tensor) -> bool:
     return False


-def _alloc_tensor(props: TensorProperties, size: Sequence[int]) -> torch.Tensor:
+def _alloc_tensor(props: TensorProperties, size: Sequence[int], device_type: str = "cuda") -> torch.Tensor:
     return torch.empty(
         size=size,
         dtype=props.dtype,
         layout=props.layout,
         requires_grad=props.requires_grad,
         pin_memory=props.pin_memory,
-        device=cast(torch.device, torch.cuda.current_device()),
+        device=cast(torch.device, _get_device_module(device_type).current_device()),
     )

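As a sketch of what the generalized `_alloc_tensor` amounts to (assuming a CUDA device is present; a custom backend would substitute its own device type), the allocation now targets the current device of whichever module backs the checkpoint's process group:

```python
import torch
from torch._utils import _get_device_module

device_type = "cuda"  # would be the process group's device type in the real code
device_module = _get_device_module(device_type)
t = torch.empty(
    4,
    dtype=torch.float32,
    device=torch.device(device_type, device_module.current_device()),
)
print(t.device)  # e.g. cuda:0
```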
@@ -255,15 +262,15 @@ def load_sharded_optimizer_state_dict(
     metadata = storage_reader.read_metadata()

     layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
+    dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type
+    device_module = _get_device_module(dp_pg_device_type)

     if dp_pg is None:
-        sharding_spec = ChunkShardingSpec(
-            dim=0,
-            placements=[
-                f"rank:{i}/cuda:{i % torch.cuda.device_count()}"
-                for i in range(dist.get_world_size())
-            ],
-        )
+        placements = []
+        for i in range(dist.get_world_size()):
+            device_info = _normalize_device_info(dp_pg_device_type, i % device_module.device_count())
+            placements.append(f"rank:{i}/{device_info}")
+        sharding_spec = ChunkShardingSpec(dim=0, placements=placements)  # type: ignore[arg-type]
     else:
         sharding_spec = _create_colwise_spec(dp_pg)

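The hunk above derives the device type from the data-parallel process group via `_get_pg_default_device` instead of hard-coding CUDA. A minimal, hedged sketch of that lookup with a single-process gloo group (the address and port are arbitrary choices for the example):

```python
import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
pg = dist.distributed_c10d._get_default_group()
pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type
print(pg_device_type)  # "cpu" for gloo; "cuda" for an nccl group
dist.destroy_process_group()
```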
@@ -282,10 +289,10 @@ def load_sharded_optimizer_state_dict(

         # value: TensorStorageMetadata
         if value.size.numel() == 1:
-            state_dict[key] = _alloc_tensor(value.properties, value.size)
+            state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
         elif dp_pg is None:
             state_dict[key] = _shard_tensor(
-                _alloc_tensor(value.properties, value.size), sharding_spec
+                _alloc_tensor(value.properties, value.size, dp_pg_device_type), sharding_spec
             )
         else:
             spec_key = key_path[2]
@@ -305,7 +312,7 @@ def load_sharded_optimizer_state_dict(
                 local_shards.append(
                     Shard(
                         tensor=_alloc_tensor(
-                            value.properties, shard_md.shard_sizes
+                            value.properties, shard_md.shard_sizes, dp_pg_device_type
                         ),
                         metadata=shard_md,
                     )
@@ -355,6 +355,7 @@ def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
 def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
     return [i_a - i_b for i_a, i_b in zip(a, b)]

+
 class _ReaderView(io.IOBase):
     def __init__(self, base_stream: io.IOBase, offset: int, len: int):
         super().__init__()
@@ -386,6 +387,16 @@ class _ReaderView(io.IOBase):
     def read(self, size=-1):
         return self.base_stream.read(size)


 def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase:
     # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader
     return _ReaderView(file, offset, length)
+
+
+def _normalize_device_info(device_type: str, device_id: int) -> str:
+    """
+    Device info normalization.
+    """
+    if device_type == "cpu":
+        return "cpu"
+    return f"{device_type}:{device_id}"
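A hedged usage note on the `_normalize_device_info` helper added above (it is a private function, and the import path below is only valid once this change is applied): it keeps `cpu` index-free and appends an index for every other device type.

```python
from torch.distributed.checkpoint.utils import _normalize_device_info

print(_normalize_device_info("cpu", 3))   # "cpu"
print(_normalize_device_info("cuda", 1))  # "cuda:1"
print(_normalize_device_info("foo", 0))   # "foo:0" for a custom backend
```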
@@ -107,7 +107,7 @@ def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
                 with device_mod.stream(stream):
                     output = obj.to(target_device)
                 # synchronize with the copy stream
-                with torch.cuda.device(target_device.index):
+                with device_mod.device(target_device.index):
                     current_stream = device_mod.current_stream()
                     # Sync the current stream with the copy stream
                     current_stream.wait_stream(stream)