import collections
import dataclasses
import io
import os
import pickle
import queue
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import (
    Callable,
    cast,
    Dict,
    IO,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
from torch import Tensor
from torch._utils import _get_device_module
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint._checkpointer import _Checkpointer
from torch.futures import Future

from .metadata import Metadata, MetadataIndex
from .planner import (
    LoadItemType,
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
    WriteItemType,
)
from .storage import StorageReader, StorageWriter, WriteResult
from .utils import _create_file_view

__all__ = ["FileSystemWriter", "FileSystemReader"]


@dataclass
class _StorageInfo:
    """This is the per entry storage info."""

    relative_path: str
    offset: int
    length: int


@dataclass
class _StoragePrefix:
    prefix: str


DEFAULT_SUFFIX = ".distcp"


class _TensorLoader(ABC):
    @abstractmethod
    def add(self, size: int, obj: object) -> None:
        pass

    @abstractmethod
    def start_loading(self) -> None:
        pass

    @abstractmethod
    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        pass


class _SerialCpuLoader(_TensorLoader):
    def __init__(self, resolve_fun: Callable) -> None:
        self.resolve_fun = resolve_fun
        self.items: List[Tuple[int, object]] = []

    def add(self, size: int, obj: object) -> None:
        self.items.append((size, obj))

    def start_loading(self) -> None:
        pass

    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        for _, obj in self.items:
            tensor = self.resolve_fun(obj).detach()
            tensor = tensor.cpu()
            if tensor.storage().size() != tensor.numel():
                tensor = tensor.clone()
            yield (
                tensor,
                obj,
            )


class _OverlappingCpuLoader(_TensorLoader):
    def __init__(
        self,
        resolve_fun: Callable,
        stream: Optional[torch.Stream] = None,
        inflight_threshhold: int = 1_000_000,
    ) -> None:
        self.resolve_fun = resolve_fun
        self.items: List[Tuple[int, object]] = []
        self.inflight_threshhold = inflight_threshhold
        self.in_flight_data = 0
        self.current_items: collections.deque = collections.deque()
        self.idx = 0
        self.started = False
        self.device_type = stream.device_type if stream else torch.device("cuda").type
        self.device_module = _get_device_module(self.device_type)
        self.stream = cast(
            torch.cuda.Stream, stream or self.device_module.current_stream()
        )
        if self.stream != self.device_module.current_stream():
            self.stream.wait_stream(self.device_module.current_stream())

    @property
    def _done(self) -> bool:
        return self.idx >= len(self.items)

    def _drain(self) -> List[Tuple[torch.Tensor, object]]:
        drained = []
        if self.in_flight_data >= self.inflight_threshhold:
            self.stream.synchronize()
        while self.in_flight_data >= self.inflight_threshhold:
            val = self.current_items.popleft()
            self.in_flight_data -= val[0].numel() * val[0].element_size()
            drained.append(val)
        return drained

    def _refill(self) -> None:
        with self.device_module.stream(self.stream):
            while not self._done and self.in_flight_data < self.inflight_threshhold:
                _, obj = self.items[self.idx]
                self.idx += 1
                tensor = self.resolve_fun(obj).detach()
                if tensor.device.type == self.device_type:
                    tensor = tensor.to(device="cpu", non_blocking=True)
                elif tensor.device == torch.device("cpu"):
                    if tensor.storage().size() != tensor.numel():
                        # this forces the tensor to be both contiguous and with minimal storage
                        tensor = tensor.clone()

                self.current_items.append(
                    (
                        tensor,
                        obj,
                    )
                )
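                # Account for the bytes now queued on the copy stream;
                # _drain uses this counter to enforce the in-flight budget.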
                self.in_flight_data += tensor.numel() * tensor.element_size()

    def _finish(self) -> Iterable[Tuple[torch.Tensor, object]]:
        assert self._done
        if len(self.current_items) > 0:
            self.stream.synchronize()
        return self.current_items

    def add(self, size: int, obj: object) -> None:
        if self.started:
            raise RuntimeError("cannot add items after loading started")
        self.items.append((size, obj))

    def start_loading(self) -> None:
        if self.started:
            return
        self.started = True
        self.items.sort(key=lambda x: x[0])
        self._refill()

    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        self.start_loading()
        while not self._done:
            drained = self._drain()
            self._refill()
            yield from drained
        yield from self._finish()


def _item_size(item: WriteItem) -> int:
    size = 1
    assert item.tensor_data is not None
    # can't use math.prod as PT needs to support older python
    for s in item.tensor_data.size:
        size *= s

    dtype = item.tensor_data.properties.dtype
    return size * torch._utils._element_size(dtype)


def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
    if bins == 1:
        return [items]

    bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
    tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]

    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
    bucket_sizes = [0 for _ in range(bins)]

    tensor_w.sort(key=_item_size, reverse=True)

    for i, wi in enumerate(bytes_w):
        buckets[i % bins].append(wi)

    for wi in tensor_w:
        # TODO replace with heapq
        idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
        buckets[idx].append(wi)
        bucket_sizes[idx] += _item_size(wi)

    return buckets


def _write_item(
    stream: io.IOBase,
    data: Union[io.BytesIO, torch.Tensor],
    write_item: WriteItem,
    storage_key: str,
) -> WriteResult:
    offset = stream.tell()

    if write_item.type == WriteItemType.BYTE_IO:
        assert isinstance(data, io.BytesIO)
        stream.write(data.getbuffer())
    else:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, cast(IO[bytes], stream))
    length = stream.tell() - offset

    return WriteResult(
        index=write_item.index,
        size_in_bytes=length,
        storage_data=_StorageInfo(storage_key, offset, length),
    )


def _write_files_from_queue(
    file_queue: queue.Queue,
    result_queue: queue.Queue,
    planner: SavePlanner,
    inflight_threshhold: int,
    use_fsync: bool,
    thread_count: int,
) -> None:
    try:
        while True:
            file_name, storage_key, write_items = file_queue.get_nowait()
            loader: _TensorLoader

            # TODO: Using the OverlappingCpuLoader with multiple threads creates
            # significant performance degradation, observed as being related to
            # cuda stream syncs. We should try to fix this and use
            # _OverlappingCpuLoader for all threaded cases.
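            # The overlapping loader copies device tensors to CPU asynchronously
            # while earlier items are being written to disk, so it is only
            # enabled for the single-threaded path below.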
            if (
                thread_count == 1
                and torch.cuda.is_available()
                and inflight_threshhold > 0
            ):
                loader = _OverlappingCpuLoader(
                    planner.resolve_data,
                    inflight_threshhold=inflight_threshhold,
                )
            else:
                loader = _SerialCpuLoader(
                    planner.resolve_data,
                )

            tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
            for write_item in tensor_w:
                loader.add(_item_size(write_item), write_item)
            loader.start_loading()

            bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
            write_results = []

            with file_name.open("wb") as stream:
                for write_item in bytes_w:
                    data = planner.resolve_data(write_item)
                    write_results.append(
                        _write_item(stream, data, write_item, storage_key)
                    )

                for tensor, write_item in loader.values():
                    assert tensor.is_cpu
                    write_results.append(
                        _write_item(stream, tensor, write_item, storage_key)
                    )

                if use_fsync:
                    os.fsync(stream.fileno())
            result_queue.put(write_results)
    except queue.Empty:
        pass


class FileSystemWriter(StorageWriter):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.

    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
            sync_files: force files to be synced to permanent storage. Defaults to True.
            thread_count: Number of IO threads to use to write. Defaults to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Default: 10MB.

        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
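
        Example::

            >>> # xdoctest: +SKIP
            >>> # A minimal usage sketch; ``save_state_dict`` is provided by
            >>> # torch.distributed.checkpoint, not by this module.
            >>> writer = FileSystemWriter("/tmp/checkpoint", thread_count=4)
            >>> # save_state_dict(state_dict=my_state_dict, storage_writer=writer)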
""" super().__init__() if not isinstance(path, Path): path = Path(path) self.path = path self.single_file_per_rank = single_file_per_rank self.sync_files = sync_files self.thread_count = thread_count self.per_thread_copy_ahead = per_thread_copy_ahead def set_up_storage_writer(self, is_coordinator: bool) -> None: pass def prepare_local_plan(self, plan: SavePlan) -> SavePlan: self.path.mkdir(parents=True, exist_ok=True) return plan def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]: new_plans = [ dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) for i, plan in enumerate(global_plan) ] return new_plans def write_data( self, plan: SavePlan, planner: SavePlanner, ) -> Future[List[WriteResult]]: storage_plan: _StoragePrefix = plan.storage_data file_count = 0 def gen_file(): nonlocal file_count file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" file_count += 1 return file_name file_queue: queue.Queue = queue.Queue() if self.single_file_per_rank: for bucket in _split_by_size_and_type(self.thread_count, plan.items): file_name = gen_file() file_queue.put((self.path / file_name, file_name, bucket)) else: for item in plan.items: file_name = gen_file() file_queue.put((self.path / file_name, file_name, [item])) result_queue: queue.Queue = queue.Queue() threads = [] for _ in range(1, self.thread_count): t = threading.Thread( target=_write_files_from_queue, args=( file_queue, result_queue, planner, self.per_thread_copy_ahead, self.sync_files, self.thread_count, ), ) t.start() threads.append(t) _write_files_from_queue( file_queue=file_queue, result_queue=result_queue, planner=planner, inflight_threshhold=self.per_thread_copy_ahead, use_fsync=self.sync_files, thread_count=self.thread_count, ) for t in threads: t.join() res = [] try: while True: res += result_queue.get_nowait() except queue.Empty: pass fut: Future[List[WriteResult]] = Future() fut.set_result(res) return fut def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: storage_md = dict() for wr_list in results: storage_md.update({wr.index: wr.storage_data for wr in wr_list}) metadata.storage_data = storage_md with (self.path / ".metadata.tmp").open("wb") as metadata_file: pickle.dump(metadata, metadata_file) if self.sync_files: os.fsync(metadata_file.fileno()) (self.path / ".metadata.tmp").rename(self.path / ".metadata") class FileSystemReader(StorageReader): def __init__(self, path: Union[str, os.PathLike]) -> None: super().__init__() if not isinstance(path, Path): path = Path(path) self.path = path self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict() def _slice_file(self, file, sinfo: _StorageInfo) -> io.IOBase: return _create_file_view(file, sinfo.offset, sinfo.length) def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: # group requests by file per_file: Dict[str, List[ReadItem]] = dict() for read_item in plan.items: item_md = self.storage_data[read_item.storage_index] path = item_md.relative_path per_file.setdefault(path, []).append(read_item) for relative_path, reqs in per_file.items(): with (self.path / relative_path).open("rb") as file: # TODO sort by offset and cache the reading for req in reqs: item_md = self.storage_data[req.storage_index] file_slice = self._slice_file(file, item_md) if req.type == LoadItemType.BYTE_IO: read_bytes = io.BytesIO(file_slice.read(item_md.length)) read_bytes.seek(0) planner.load_bytes(req, read_bytes) else: tensor = cast( Tensor, torch.load(cast(IO[bytes], file_slice), map_location="cpu"), ) tensor 
                        target_tensor = planner.resolve_tensor(req).detach()

                        assert (
                            target_tensor.size() == tensor.size()
                        ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                        target_tensor.copy_(tensor)
                        planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    # Implementing the abstract function in StorageReader
    def read_metadata(self) -> Metadata:
        with (self.path / ".metadata").open("rb") as metadata_file:
            return pickle.load(metadata_file)

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        return plan

    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        return global_plan


class _FileSystemCheckpointer(_Checkpointer):
    """An implementation of :py:class:`torch.distributed.checkpoint.checkpointer.Checkpointer`
    for the file system.

    Wraps the creation and usage of ``FileSystemWriter`` and ``FileSystemReader``.

    .. warning::
        This feature is experimental and subject to removal/change.

    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        *,
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        process_group: Optional[dist.ProcessGroup] = None,
        coordinator_rank: int = 0,
        no_dist: bool = False,
        load_planner: Optional[LoadPlanner] = None,
        save_planner: Optional[SavePlanner] = None,
    ) -> None:
        """Initializes checkpointing defaults, including ``FileSystemWriter`` and ``FileSystemReader``.

        Args:
            path: The directory to store/load checkpoints.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
            sync_files: force files to be synced to permanent storage. Defaults to True.
            thread_count: Number of IO threads to use to write. Defaults to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Default: 10MB.
            process_group: ProcessGroup to be used for cross-rank synchronization.
            coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default.
            no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``)
            load_planner: Instance of LoadPlanner to use when loading.
            save_planner: Instance of SavePlanner to use when saving.
        """
        storage_writer = FileSystemWriter(
            path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead
        )
        storage_reader = FileSystemReader(path)

        super().__init__(
            storage_writer,
            storage_reader,
            process_group=process_group,
            coordinator_rank=coordinator_rank,
            no_dist=no_dist,
            load_planner=load_planner,
            save_planner=save_planner,
        )
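

# A minimal end-to-end usage sketch (illustrative only; the checkpoint path and
# state_dict are assumptions, and ``save_state_dict``/``load_state_dict`` come
# from torch.distributed.checkpoint, not from this module):
#
#     import torch.distributed.checkpoint as dist_cp
#
#     state_dict = {"model": model.state_dict()}
#     dist_cp.save_state_dict(
#         state_dict=state_dict,
#         storage_writer=FileSystemWriter("/tmp/ckpt"),
#     )
#     dist_cp.load_state_dict(
#         state_dict=state_dict,
#         storage_reader=FileSystemReader("/tmp/ckpt"),
#     )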