[DTensor] clean up _local_shard_size_and_offset (#150650)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150650
Approved by: https://github.com/wanchaol, https://github.com/XilunWu
ghstack dependencies: #150490
This commit is contained in:
parent 3532dd4f1e
commit a8b48ff14c
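The diff below renames the private helper Shard._local_shard_size_on_dim to Shard._local_shard_size_and_offset and drops its return_offset flag, so the offset is always computed. As a rough, non-authoritative sketch of the new call shape (it assumes a PyTorch build that already contains this PR; the method is private and may change again):

from torch.distributed.tensor import Shard

# Old form (removed by this PR):
#   size, offset = Shard._local_shard_size_on_dim(10, 4, 1, return_offset=True)
# New form: a (size, offset) pair is always returned, no return_offset flag.
size, offset = Shard._local_shard_size_and_offset(10, 4, 1)
print(size, offset)  # 3 3 -- torch.chunk-style split of 10 into 4 pieces, rank 1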
@@ -506,11 +506,13 @@ class DistTensorRandomOpTest(DTensorTestBase):
                 shard_dim = placement.dim
                 local_shard_list_on_dim[shard_dim] = []
                 for shard_idx_on_dim in range(mesh_dim_size):
-                    shard_size, shard_offset = placement._local_shard_size_on_dim(
+                    (
+                        shard_size,
+                        shard_offset,
+                    ) = placement._local_shard_size_and_offset(
                         dtensor_shape[shard_dim],
                         mesh_dim_size,
                         shard_idx_on_dim,
-                        return_offset=True,
                     )
                     local_shard_list_on_dim[shard_dim].append(
                         (shard_offset, shard_size)
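As a standalone illustration of what the test loop above collects (a sketch assuming a PyTorch build containing this PR): one (offset, size) pair per shard index, and together the pairs tile the sharded dimension.

from torch.distributed.tensor import Shard

dim_size, mesh_dim_size = 10, 4
shard_list = []
for shard_idx_on_dim in range(mesh_dim_size):
    # (size, offset) of the shard owned by this index on the mesh dim
    size, offset = Shard._local_shard_size_and_offset(
        dim_size, mesh_dim_size, shard_idx_on_dim
    )
    shard_list.append((offset, size))

print(shard_list)  # [(0, 3), (3, 3), (6, 3), (9, 1)]
assert sum(size for _, size in shard_list) == dim_size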
@@ -92,11 +92,10 @@ class _MaskPartial(Partial):
         assert self.offset_shape is not None, (
             "offset_shape needs to be set for _MaskPartial"
         )
-        local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
+        local_shard_size, local_offset_on_dim = Shard._local_shard_size_and_offset(
             self.offset_shape[self.offset_dim],
             num_chunks,
             mesh.get_local_rank(mesh_dim),
-            return_offset=True,
         )
         # Build the input mask and save it for the current partial placement
         # this is so that the output of embedding op can reuse the same partial
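For context, a rough sketch of the idea behind the mask built here -- illustrative only, not the actual _MaskPartial code (it assumes a build with this PR and a hypothetical row-sharded embedding table): indices outside this rank's row range [offset, offset + size) are masked, the rest are shifted into local row coordinates.

import torch
from torch.distributed.tensor import Shard

num_embeddings, num_chunks, local_rank = 10, 4, 1
local_size, local_offset = Shard._local_shard_size_and_offset(
    num_embeddings, num_chunks, local_rank
)  # rank 1 owns embedding rows [3, 6)

token_ids = torch.tensor([0, 3, 5, 9])
mask = (token_ids < local_offset) | (token_ids >= local_offset + local_size)
local_ids = torch.where(mask, torch.zeros_like(token_ids), token_ids - local_offset)
print(mask.tolist(), local_ids.tolist())  # [True, False, False, True] [0, 0, 2, 0]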
@@ -329,12 +329,13 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
             if isinstance(placement, Shard):
                 mesh_dim_size = mesh.size(idx)
                 shard_dim = placement.dim
-                local_size_on_rank_0[shard_dim] = placement._local_shard_size_on_dim(
-                    dtensor_shape[shard_dim],
-                    mesh_dim_size,
-                    0,
-                    return_offset=False,
-                )[0]
+                local_size_on_rank_0[shard_dim], _ = (
+                    placement._local_shard_size_and_offset(
+                        dtensor_shape[shard_dim],
+                        mesh_dim_size,
+                        0,
+                    )
+                )

         from torch.distributed.tensor._ops.utils import prod

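A small sketch of why the tracker can always query rank 0 here (again assuming a build with this PR): under torch.chunk-style chunking, the first rank always holds the full, i.e. largest, chunk.

from torch.distributed.tensor import Shard

for dim_size in (8, 10, 5):
    # shard size per rank when chunking dim_size over 4 chunks
    sizes = [Shard._local_shard_size_and_offset(dim_size, 4, r)[0] for r in range(4)]
    assert sizes[0] == max(sizes)
    print(dim_size, sizes)  # 8 [2, 2, 2, 2] / 10 [3, 3, 3, 1] / 5 [2, 2, 1, 0]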
@@ -73,7 +73,7 @@ def _gen_transform_infos_non_cached(
         if i < device_mesh.ndim - 1:
             # calculate and save the logical shape for this sharding
             mesh_dim_size = device_mesh.size(mesh_dim=i)
-            local_shard_size, _ = src._local_shard_size_on_dim(
+            local_shard_size, _ = src._local_shard_size_and_offset(
                 current_logical_shape[src.dim],
                 mesh_dim_size,
                 my_coordinate[i],
@@ -128,11 +128,10 @@ def compute_local_shape_and_global_offset(
             assert shard_dim < len(local_shape), (
                 f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
             )
-            shard_size, shard_offset = placement._local_shard_size_on_dim(
+            shard_size, shard_offset = placement._local_shard_size_and_offset(
                 local_shape[shard_dim],
                 mesh_dim_size,
                 my_coordinate[mesh_dim],
-                return_offset=True,
             )

             local_shape[shard_dim] = shard_size
@@ -111,11 +111,10 @@ def _compute_local_shape_and_global_offset(
             assert shard_dim < len(local_shape), (
                 f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
             )
-            shard_size, shard_offset = placement._local_shard_size_on_dim(
+            shard_size, shard_offset = placement._local_shard_size_and_offset(
                 local_shape[shard_dim],
                 mesh_dim_size,
                 my_coordinate[idx],
-                return_offset=True,
            )

             local_shape[shard_dim] = shard_size
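Both call sites above pass the helper the current local size rather than the global size, which is what the curr_local_size parameter name in the hunk below reflects. A standalone sketch of that chaining when one tensor dim is sharded over two mesh dims (illustrative values; assumes a build with this PR):

from torch.distributed.tensor import Shard

global_dim_size = 13
mesh_dim_sizes = (2, 3)  # hypothetical 2x3 mesh, tensor dim 0 sharded on both mesh dims
my_coordinate = (1, 2)   # this rank's coordinate on each mesh dim

local_size = global_dim_size
for mesh_dim_size, coord in zip(mesh_dim_sizes, my_coordinate):
    # each step re-chunks whatever is left locally after the previous mesh dim
    local_size, _ = Shard._local_shard_size_and_offset(local_size, mesh_dim_size, coord)
print(local_size)  # 13 -> 6 after mesh dim 0 -> 2 after mesh dim 1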
@@ -109,32 +109,39 @@ class Shard(Placement):
         return shard_list, pad_sizes

     @staticmethod
-    def _local_shard_size_on_dim(
-        size_on_dim: int,
+    def _local_shard_size_and_offset(
+        curr_local_size: int,
         num_chunks: int,
         rank: int,
-        return_offset: bool = False,
     ) -> tuple[int, int]:
         """
-        returns the local shard size and offset on a given tensor dim
+        Given the size of the current local tensor (which may already be sharded on some dimensions),
+        computes the new local shard size and offset given the desired number of chunks
+        (num_chunks is generally equal to the size of the current sharding dim).
+
+        Note: new local shard offset is relative to the current sharded tensor, not the global tensor.
+        See `_utils.compute_local_shape_and_global_offset` for computing global offset.
+
+        Returns (new local shard size, offset)
+
         """
         # Compute the chunk size inline with ``torch.chunk``
-        if size_on_dim % num_chunks == 0:
-            full_chunk_size = size_on_dim // num_chunks
-            return full_chunk_size, full_chunk_size * rank if return_offset else -1
+        if curr_local_size % num_chunks == 0:
+            full_chunk_size = curr_local_size // num_chunks
+            return full_chunk_size, full_chunk_size * rank

         # uneven sharding case
-        full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks
+        full_chunk_size = (curr_local_size + num_chunks - 1) // num_chunks
         shard_starting_idx = full_chunk_size * rank

-        if size_on_dim < shard_starting_idx:
-            return 0, size_on_dim if return_offset else -1
+        if curr_local_size < shard_starting_idx:
+            return 0, curr_local_size
         else:
             local_shard_size = (
-                min(size_on_dim, shard_starting_idx + full_chunk_size)
+                min(curr_local_size, shard_starting_idx + full_chunk_size)
                 - shard_starting_idx
             )
-            return local_shard_size, shard_starting_idx if return_offset else -1
+            return local_shard_size, shard_starting_idx

     def _shard_tensor(
         self,
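A quick hand-check of the three branches introduced above (it needs a PyTorch build that already contains this PR; the expected values just follow the torch.chunk-style rule spelled out in the code):

from torch.distributed.tensor import Shard

# Even split: 12 rows over 4 chunks -> every rank gets 3 rows, offsets are multiples of 3.
assert Shard._local_shard_size_and_offset(12, 4, 2) == (3, 6)

# Uneven split: 10 rows over 4 chunks -> chunk size ceil(10 / 4) = 3, so ranks 0-2
# get 3 rows each and rank 3 gets the 1-row remainder starting at offset 9.
assert Shard._local_shard_size_and_offset(10, 4, 3) == (1, 9)

# Degenerate shard: 5 rows over 4 chunks -> chunk size 2; rank 3 would start at
# offset 6, past the end of the data, so it gets an empty shard.
assert Shard._local_shard_size_and_offset(5, 4, 3) == (0, 5)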
@@ -324,7 +331,7 @@ class Shard(Placement):
             new_tensor = unpad_tensor(new_tensor, self.dim, old_dim_unpad_size)  # type: ignore[possibly-undefined]

         if new_dim_padding:
-            local_shard_size_on_new_dim = self._local_shard_size_on_dim(
+            local_shard_size_on_new_dim = self._local_shard_size_and_offset(
                 new_dim_logical_size, num_chunks, my_coordinate[mesh_dim]
             )[0]
             new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim  # type: ignore[possibly-undefined]