[DTensor] clean up _local_shard_size_and_offset (#150650)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150650
Approved by: https://github.com/wanchaol, https://github.com/XilunWu
ghstack dependencies: #150490
Authored by Will Constable on 2025-04-09 09:59:57 -07:00; committed by PyTorch MergeBot
parent 3532dd4f1e
commit a8b48ff14c
7 changed files with 35 additions and 28 deletions

@@ -506,11 +506,13 @@ class DistTensorRandomOpTest(DTensorTestBase):
                 shard_dim = placement.dim
                 local_shard_list_on_dim[shard_dim] = []
                 for shard_idx_on_dim in range(mesh_dim_size):
-                    shard_size, shard_offset = placement._local_shard_size_on_dim(
+                    (
+                        shard_size,
+                        shard_offset,
+                    ) = placement._local_shard_size_and_offset(
                         dtensor_shape[shard_dim],
                         mesh_dim_size,
                         shard_idx_on_dim,
-                        return_offset=True,
                     )
                     local_shard_list_on_dim[shard_dim].append(
                         (shard_offset, shard_size)

@@ -92,11 +92,10 @@ class _MaskPartial(Partial):
         assert self.offset_shape is not None, (
             "offset_shape needs to be set for _MaskPartial"
         )
-        local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
+        local_shard_size, local_offset_on_dim = Shard._local_shard_size_and_offset(
             self.offset_shape[self.offset_dim],
             num_chunks,
             mesh.get_local_rank(mesh_dim),
-            return_offset=True,
         )
         # Build the input mask and save it for the current partial placement
         # this is so that the output of embedding op can reuse the same partial

@@ -329,12 +329,13 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
             if isinstance(placement, Shard):
                 mesh_dim_size = mesh.size(idx)
                 shard_dim = placement.dim
-                local_size_on_rank_0[shard_dim] = placement._local_shard_size_on_dim(
-                    dtensor_shape[shard_dim],
-                    mesh_dim_size,
-                    0,
-                    return_offset=False,
-                )[0]
+                local_size_on_rank_0[shard_dim], _ = (
+                    placement._local_shard_size_and_offset(
+                        dtensor_shape[shard_dim],
+                        mesh_dim_size,
+                        0,
+                    )
+                )

         from torch.distributed.tensor._ops.utils import prod

@@ -73,7 +73,7 @@ def _gen_transform_infos_non_cached(
         if i < device_mesh.ndim - 1:
             # calculate and save the logical shape for this sharding
             mesh_dim_size = device_mesh.size(mesh_dim=i)
-            local_shard_size, _ = src._local_shard_size_on_dim(
+            local_shard_size, _ = src._local_shard_size_and_offset(
                 current_logical_shape[src.dim],
                 mesh_dim_size,
                 my_coordinate[i],

@@ -128,11 +128,10 @@ def compute_local_shape_and_global_offset(
             assert shard_dim < len(local_shape), (
                 f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
             )
-            shard_size, shard_offset = placement._local_shard_size_on_dim(
+            shard_size, shard_offset = placement._local_shard_size_and_offset(
                 local_shape[shard_dim],
                 mesh_dim_size,
                 my_coordinate[mesh_dim],
-                return_offset=True,
             )
             local_shape[shard_dim] = shard_size

@@ -111,11 +111,10 @@ def _compute_local_shape_and_global_offset(
             assert shard_dim < len(local_shape), (
                 f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
             )
-            shard_size, shard_offset = placement._local_shard_size_on_dim(
+            shard_size, shard_offset = placement._local_shard_size_and_offset(
                 local_shape[shard_dim],
                 mesh_dim_size,
                 my_coordinate[idx],
-                return_offset=True,
             )
             local_shape[shard_dim] = shard_size
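
(For reviewers, a standalone sketch of the call pattern in the two hunks above: the global shape is reduced to a local shape one mesh dim at a time. This is illustration only; it calls the private staticmethod directly, and the shapes, mesh, coordinate, and the `torch.distributed.tensor.placement_types` import path are assumptions for the example, not part of this PR.)

from torch.distributed.tensor.placement_types import Shard  # assumed import path

# Hypothetical setup: a (10, 8) global tensor placed as [Shard(0), Shard(1)]
# on a 4 x 2 mesh; this rank sits at mesh coordinate (3, 1).
global_shape = [10, 8]
placements = [Shard(0), Shard(1)]  # one placement per mesh dim
mesh_shape = (4, 2)
my_coordinate = (3, 1)

local_shape = list(global_shape)
for mesh_dim, placement in enumerate(placements):
    shard_size, _offset = Shard._local_shard_size_and_offset(
        local_shape[placement.dim],  # current (possibly already-sharded) size
        mesh_shape[mesh_dim],        # num_chunks for this mesh dim
        my_coordinate[mesh_dim],     # this rank's index on the mesh dim
    )
    local_shape[placement.dim] = shard_size

print(local_shape)  # [1, 4]: dim 0 splits as 3, 3, 3, 1, so the last rank gets 1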

@@ -109,32 +109,39 @@ class Shard(Placement):
         return shard_list, pad_sizes

     @staticmethod
-    def _local_shard_size_on_dim(
-        size_on_dim: int,
+    def _local_shard_size_and_offset(
+        curr_local_size: int,
         num_chunks: int,
         rank: int,
-        return_offset: bool = False,
     ) -> tuple[int, int]:
         """
-        returns the local shard size and offset on a given tensor dim
+        Given the size of the current local tensor (which may already be sharded on some dimensions),
+        computes the new local shard size and offset given the desired number of chunks
+        (num_chunks is generally equal to the size of the current sharding dim).
+
+        Note: new local shard offset is relative to the current sharded tensor, not the global tensor.
+        See `_utils.compute_local_shape_and_global_offset` for computing global offset.
+
+        Returns (new local shard size, offset)
         """
         # Compute the chunk size inline with ``torch.chunk``
-        if size_on_dim % num_chunks == 0:
-            full_chunk_size = size_on_dim // num_chunks
-            return full_chunk_size, full_chunk_size * rank if return_offset else -1
+        if curr_local_size % num_chunks == 0:
+            full_chunk_size = curr_local_size // num_chunks
+            return full_chunk_size, full_chunk_size * rank

         # uneven sharding case
-        full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks
+        full_chunk_size = (curr_local_size + num_chunks - 1) // num_chunks
         shard_starting_idx = full_chunk_size * rank

-        if size_on_dim < shard_starting_idx:
-            return 0, size_on_dim if return_offset else -1
+        if curr_local_size < shard_starting_idx:
+            return 0, curr_local_size
         else:
             local_shard_size = (
-                min(size_on_dim, shard_starting_idx + full_chunk_size)
+                min(curr_local_size, shard_starting_idx + full_chunk_size)
                 - shard_starting_idx
             )
-            return local_shard_size, shard_starting_idx if return_offset else -1
+            return local_shard_size, shard_starting_idx

     def _shard_tensor(
         self,
@@ -324,7 +331,7 @@ class Shard(Placement):
             new_tensor = unpad_tensor(new_tensor, self.dim, old_dim_unpad_size)  # type: ignore[possibly-undefined]

         if new_dim_padding:
-            local_shard_size_on_new_dim = self._local_shard_size_on_dim(
+            local_shard_size_on_new_dim = self._local_shard_size_and_offset(
                 new_dim_logical_size, num_chunks, my_coordinate[mesh_dim]
             )[0]
             new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim  # type: ignore[possibly-undefined]
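
(For context on the chunking rule in the `Shard` hunk above: a minimal self-contained re-implementation of the same math, shown for illustration only rather than as the DTensor API, demonstrating the size/offset pairs produced when a dimension of size 10 is split into 4 chunks, consistent with `torch.chunk` semantics.)

def local_shard_size_and_offset(curr_local_size: int, num_chunks: int, rank: int) -> tuple[int, int]:
    # Mirrors the rule in Shard._local_shard_size_and_offset above.
    if curr_local_size % num_chunks == 0:
        full_chunk_size = curr_local_size // num_chunks
        return full_chunk_size, full_chunk_size * rank
    # uneven case: chunks are padded up to ceil(size / num_chunks)
    full_chunk_size = (curr_local_size + num_chunks - 1) // num_chunks
    shard_starting_idx = full_chunk_size * rank
    if curr_local_size < shard_starting_idx:
        return 0, curr_local_size
    return (
        min(curr_local_size, shard_starting_idx + full_chunk_size) - shard_starting_idx,
        shard_starting_idx,
    )

# Size 10 split 4 ways chunks as 3, 3, 3, 1 -- the same split torch.chunk produces:
# rank 0 -> (3, 0), rank 1 -> (3, 3), rank 2 -> (3, 6), rank 3 -> (1, 9)
for rank in range(4):
    print(rank, local_shard_size_and_offset(10, 4, rank))

When a rank's starting index falls past the end of the dimension, the helper returns a zero-sized shard with the offset clamped to the current size, matching the `return 0, curr_local_size` branch above.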