# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard


def _get_remote_device_str(rank, device_type, num_devices_per_node):
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    elif device_type.lower() == "hpu":
        return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(
            dist.get_global_rank(pg, r),
            device_type,
            num_devices_per_node,
        )
        for r in range(len(chunk_sizes))
    ]
    if len(chunk_sizes) != len(chunk_offsets) or len(chunk_sizes) != len(placements):
        raise AssertionError(
            f"Expected chunk_sizes, chunk_offsets, and placements to have the same length, "
            f"got {len(chunk_sizes)}, {len(chunk_offsets)}, {len(placements)}"
        )
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )
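

# Illustrative usage sketch (hypothetical helper, not called anywhere in this
# module): it assumes a process group has already been initialized via
# dist.init_process_group and that every rank passes the same full tensor.
# The tensor shape and device count below are arbitrary placeholders.
def _example_create_chunk_sharded_tensor() -> ShardedTensor:
    pg = distributed_c10d._get_default_group()
    full_tensor = torch.arange(16, dtype=torch.float32).reshape(8, 2)
    # Each rank keeps at most one dim-0 chunk; ranks beyond len(chunks) hold none.
    return _create_chunk_sharded_tensor(
        full_tensor,
        rank=dist.get_rank(pg),
        world_size=dist.get_world_size(pg),
        num_devices_per_node=max(torch.cuda.device_count(), 1),
        pg=pg,
    )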


def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.detach().clone()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]

    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )
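

# Illustrative usage sketch (hypothetical helper): assumes an initialized process
# group and a 1-D "cuda" DeviceMesh, mirroring the plain-FSDP case in which the
# resulting placements are [Shard(0)].
def _example_create_chunk_dtensor() -> DTensor:
    mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
    full_tensor = torch.randn(8, 2)
    # The returned DTensor's .to_local() is this rank's dim-0 chunk.
    return _create_chunk_dtensor(full_tensor, rank=dist.get_rank(), device_mesh=mesh)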


def _all_gather_dtensor(
    tensor: DTensor,
    root_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    if root_mesh != tensor.device_mesh:
        raise AssertionError("The device mesh of a tensor should be a root mesh.")

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
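

# Illustrative usage sketch (hypothetical helper): reassemble the full tensor from
# a DTensor produced by _create_chunk_dtensor on its root mesh, so the mesh check
# above passes. After replicating the formerly sharded placement, every rank's
# local tensor is the complete, unsharded tensor.
def _example_all_gather_dtensor(dt: DTensor) -> torch.Tensor:
    return _all_gather_dtensor(dt, dt.device_mesh)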