diff --git a/torch/_utils.py b/torch/_utils.py index bcb2c3cdad6..36142a64547 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -1,4 +1,5 @@ import copyreg +import functools import sys import traceback import warnings @@ -839,3 +840,13 @@ def classproperty(func): # Whether we are compiling with torch.compile or not def is_compiling(): return False + + +@functools.lru_cache(2) +def _get_device_module(device_type: str): + device_module = getattr(torch, device_type, None) + if device_module is None: + raise RuntimeError( + f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'." + ) + return device_module diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py index b8d1c24d61a..0d37924cd36 100644 --- a/torch/distributed/checkpoint/_fsspec_filesystem.py +++ b/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -18,6 +18,7 @@ import fsspec import torch from fsspec.core import url_to_fs from torch import Tensor +from torch._utils import _get_device_module from torch.distributed._shard._utils import narrow_tensor_by_index from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex @@ -114,7 +115,7 @@ class _OverlappingCpuLoader(_TensorLoader): def __init__( self, resolve_fun: Callable, - stream: Union[None, io.RawIOBase, torch._C._CudaStreamBase] = None, + stream: Union[None, io.RawIOBase, torch.Stream] = None, inflight_threshhold: int = 1_000_000, ): self.resolve_fun = resolve_fun @@ -124,9 +125,11 @@ class _OverlappingCpuLoader(_TensorLoader): self.current_items: collections.deque = collections.deque() self.idx = 0 self.started = False - self.stream = stream or torch.cuda.current_stream() - if self.stream != torch.cuda.current_stream(): - self.stream.wait_stream(torch.cuda.current_stream()) + self.device_type = stream.device_type if stream else torch.device("cuda").type + self.device_module = _get_device_module(self.device_type) + self.stream = stream or self.device_module.current_stream() + if self.stream != self.device_module.current_stream(): + self.stream.wait_stream(self.device_module.current_stream()) @property def _done(self): @@ -143,7 +146,7 @@ class _OverlappingCpuLoader(_TensorLoader): return drained def _refill(self): - with torch.cuda.stream(self.stream): + with self.device_module.stream(self.stream): while ( not self._done and self.in_flight_data < self.inflight_threshhold @@ -151,7 +154,7 @@ class _OverlappingCpuLoader(_TensorLoader): _, obj = self.items[self.idx] self.idx += 1 tensor = self.resolve_fun(obj).detach() - if tensor.is_cuda: + if tensor.device.type == self.device_type: tensor = tensor.to(device="cpu", non_blocking=True) elif tensor.device == torch.device("cpu"): if tensor.storage().size() != tensor.numel(): @@ -232,7 +235,7 @@ def _split_by_size_and_type( def _write_item( - stream: Optional[Union[io.RawIOBase, torch._C._CudaStreamBase]], + stream: Optional[Union[io.RawIOBase, torch.Stream]], data: Union[io.BytesIO, torch.Tensor], write_item: WriteItem, storage_key: str, @@ -294,7 +297,7 @@ def _write_files_from_queue( ) for tensor, write_item in loader.values(): - assert not tensor.is_cuda + assert tensor.is_cpu write_results.append( _write_item(stream, tensor, write_item, storage_key) ) diff --git a/torch/distributed/checkpoint/_sharded_tensor_utils.py b/torch/distributed/checkpoint/_sharded_tensor_utils.py index 8d39be25221..07bbdc9a30f 100644 --- a/torch/distributed/checkpoint/_sharded_tensor_utils.py +++ b/torch/distributed/checkpoint/_sharded_tensor_utils.py @@ -26,7 +26,7 @@ from ._traverse import ( STATE_DICT_ITEM, ) -from .utils import _element_wise_add +from .utils import _element_wise_add, _normalize_device_info # TODO: We need to refactor this code. @@ -83,6 +83,7 @@ def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata()) other_rank = 0 if dist.get_rank() > 0 else 1 + device_info = _normalize_device_info(inner_shard.tensor.device.type, 0) # Remove the outer ST shard the inner ST covers for i, shard_md in enumerate(st_meta.shards_metadata): @@ -92,7 +93,7 @@ def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: # Attribute other rank for the other shards for shard_md in st_meta.shards_metadata: - shard_md.placement = _remote_device(f"rank:{other_rank}/cuda:0") + shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}") # Add other inner shards from the inner tensor for inner_md in inner_st.metadata().shards_metadata: @@ -104,7 +105,7 @@ def _flatten_sharded_tensors(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: inner_md.shard_offsets, ), shard_sizes=inner_md.shard_sizes, - placement=f"rank:{other_rank}/cuda:0", + placement=f"rank:{other_rank}/{device_info}", ) ) diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index c6ef58c6550..d23bf275127 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -39,6 +39,7 @@ from .planner import ( from .utils import _create_file_view from torch.distributed._shard._utils import narrow_tensor_by_index +from torch._utils import _get_device_module __all__ = [ "FileSystemWriter", @@ -126,9 +127,11 @@ class _OverlappingCpuLoader(_TensorLoader): self.current_items: collections.deque = collections.deque() self.idx = 0 self.started = False - self.stream = stream or torch.cuda.current_stream() - if self.stream != torch.cuda.current_stream(): - self.stream.wait_stream(torch.cuda.current_stream()) + self.device_type = stream.device_type if stream else torch.device("cuda").type + self.device_module = _get_device_module(self.device_type) + self.stream = stream or self.device_module.current_stream() + if self.stream != self.device_module.current_stream(): + self.stream.wait_stream(self.device_module.current_stream()) @property def _done(self): @@ -145,7 +148,7 @@ class _OverlappingCpuLoader(_TensorLoader): return drained def _refill(self): - with torch.cuda.stream(self.stream): + with self.device_module.stream(self.stream): while ( not self._done and self.in_flight_data < self.inflight_threshhold @@ -153,7 +156,7 @@ class _OverlappingCpuLoader(_TensorLoader): _, obj = self.items[self.idx] self.idx += 1 tensor = self.resolve_fun(obj).detach() - if tensor.is_cuda: + if tensor.device.type == self.device_type: tensor = tensor.to(device="cpu", non_blocking=True) elif tensor.device == torch.device("cpu"): if tensor.storage().size() != tensor.numel(): @@ -292,7 +295,7 @@ def _write_files_from_queue( ) for tensor, write_item in loader.values(): - assert not tensor.is_cuda + assert tensor.is_cpu write_results.append( _write_item(stream, tensor, write_item, storage_key) ) diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py index 67e45042d11..0d359aa20ac 100644 --- a/torch/distributed/checkpoint/optimizer.py +++ b/torch/distributed/checkpoint/optimizer.py @@ -38,8 +38,11 @@ from torch.distributed.checkpoint._nested_dict import unflatten_state_dict from torch.distributed.checkpoint.utils import ( _element_wise_add, _element_wise_sub, + _normalize_device_info ) +from torch._utils import _get_device_module + STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]] @@ -49,23 +52,27 @@ __all__ = [ ] -def _gen_rank_device(global_rank: int) -> str: - if torch.cuda.is_available(): - return f"cuda:{global_rank % torch.cuda.device_count()}" +def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str: + if device_type == "cpu": + return "cpu" + device_module = _get_device_module(device_type) + if device_module.is_available(): + return _normalize_device_info(device_type, global_rank % device_module.device_count()) return "cpu" def _create_colwise_spec( pg: Optional[dist.ProcessGroup] = None, ) -> ChunkShardingSpec: + pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type if pg is None: placements = [ - f"rank:{idx}/{_gen_rank_device(idx)}" + f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}" for idx in range(dist.get_world_size()) ] else: placements = [ - f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx))}" + f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}" for idx in range(pg.size()) ] return ChunkShardingSpec( @@ -92,14 +99,14 @@ def _is_nested_tensor(val: torch.Tensor) -> bool: return False -def _alloc_tensor(props: TensorProperties, size: Sequence[int]) -> torch.Tensor: +def _alloc_tensor(props: TensorProperties, size: Sequence[int], device_type: str = "cuda") -> torch.Tensor: return torch.empty( size=size, dtype=props.dtype, layout=props.layout, requires_grad=props.requires_grad, pin_memory=props.pin_memory, - device=cast(torch.device, torch.cuda.current_device()), + device=cast(torch.device, _get_device_module(device_type).current_device()), ) @@ -255,15 +262,15 @@ def load_sharded_optimizer_state_dict( metadata = storage_reader.read_metadata() layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict) + dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type + device_module = _get_device_module(dp_pg_device_type) if dp_pg is None: - sharding_spec = ChunkShardingSpec( - dim=0, - placements=[ - f"rank:{i}/cuda:{i % torch.cuda.device_count()}" - for i in range(dist.get_world_size()) - ], - ) + placements = [] + for i in range(dist.get_world_size()): + device_info = _normalize_device_info(dp_pg_device_type, i % device_module.device_count()) + placements.append(f"rank:{i}/{device_info}") + sharding_spec = ChunkShardingSpec(dim=0, placements=placements) # type: ignore[arg-type] else: sharding_spec = _create_colwise_spec(dp_pg) @@ -282,10 +289,10 @@ def load_sharded_optimizer_state_dict( # value: TensorStorageMetadata if value.size.numel() == 1: - state_dict[key] = _alloc_tensor(value.properties, value.size) + state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type) elif dp_pg is None: state_dict[key] = _shard_tensor( - _alloc_tensor(value.properties, value.size), sharding_spec + _alloc_tensor(value.properties, value.size, dp_pg_device_type), sharding_spec ) else: spec_key = key_path[2] @@ -305,7 +312,7 @@ def load_sharded_optimizer_state_dict( local_shards.append( Shard( tensor=_alloc_tensor( - value.properties, shard_md.shard_sizes + value.properties, shard_md.shard_sizes, dp_pg_device_type ), metadata=shard_md, ) diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index 546b0bca2f2..d1105038f08 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -355,6 +355,7 @@ def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]: def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]: return [i_a - i_b for i_a, i_b in zip(a, b)] + class _ReaderView(io.IOBase): def __init__(self, base_stream: io.IOBase, offset: int, len: int): super().__init__() @@ -386,6 +387,16 @@ class _ReaderView(io.IOBase): def read(self, size=-1): return self.base_stream.read(size) + def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase: # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader return _ReaderView(file, offset, length) + + +def _normalize_device_info(device_type: str, device_id: int) -> str: + """ + Device info normalization. + """ + if device_type == "cpu": + return "cpu" + return f"{device_type}:{device_id}" diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index eb7dc5a546d..c7ab8b96381 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -107,7 +107,7 @@ def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies): with device_mod.stream(stream): output = obj.to(target_device) # synchronize with the copy stream - with torch.cuda.device(target_device.index): + with device_mod.device(target_device.index): current_stream = device_mod.current_stream() # Sync the current stream with the copy stream current_stream.wait_stream(stream)