[PT-D][Sharding] Enable ops needed in the transformer model training (#75374)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75374

While working with the FairSeq and MetaSeq codebases (which are essentially transformer models), we found that many ops are not yet supported by ShardedTensor. We now implement simple versions of these ops so that we can at least run a transformer example:

The ops include: chunk, transpose, view, masked_fill, dropout, softmax, and type_as.
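For illustration, here is a rough sketch (not part of this commit) of how these ops can now be called on a ShardedTensor. It assumes a 4-rank NCCL process group with one GPU per rank has already been initialized, and mirrors the setup used in the tests added below:

```python
import torch
import torch.distributed as dist
from torch.distributed._shard import _shard_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# Assumes dist.init_process_group("nccl", ...) was already called on 4 ranks,
# one GPU per rank; the placement strings follow the tests in this diff.
rank = dist.get_rank()
spec = ChunkShardingSpec(
    dim=0,
    placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
)

# Shard a dense tensor along dim 0, then exercise the newly supported ops.
st = _shard_tensor(torch.rand(16, 32).cuda(rank), spec)
chunks = torch.chunk(st, 4, dim=-1)            # chunk along a non-sharding dim
st_t = st.transpose(0, 1)                      # sharding dim follows the swap
st_v = st.view(4, 4, 32)                       # view keeps the sharding dim
mask = torch.zeros(32, dtype=torch.bool).cuda(rank)
st_m = st.masked_fill(mask, 0.0)               # mask is broadcast per shard
st_d = torch.nn.functional.dropout(st, p=0.5)  # applied to each local shard
st_s = torch.nn.functional.softmax(st, dim=1)  # softmax on a non-sharding dim
st_f = st.type_as(torch.zeros(1).half().cuda(rank))
```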

We also isolate the common logic of registering simple ops into shared helper functions, so that registering a new op in the future requires implementing at most three functions (an early-stop check, an extra validation check, and a customized compute function).
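As a hypothetical sketch (not part of this commit), registering a new op with these helpers could look roughly like the following. The `_my_*` names and the choice of `torch.tanh` / `torch.nn.functional.leaky_relu` are made up for illustration, and the import path is inferred from the relative imports in this diff:

```python
import torch
from torch.distributed._shard.sharded_tensor._ops._common import (
    _register_sharded_op_on_local_shards,
    _register_sharded_op_on_local_tensor,
)

# Simplest case: run the op on every local shard, no extra functions needed.
_register_sharded_op_on_local_shards(torch.nn.functional.leaky_relu)

# More involved case: provide up to three functions for the new op.
def _my_early_stop(*args, **kwargs):
    # Return True to skip the op and hand back the input ShardedTensor as-is.
    return False

def _my_extra_check(*args, **kwargs):
    # Raise if the inputs are not supported by the sharded implementation.
    pass

def _my_compute(args, kwargs, pg):
    st = args[0]
    # Compute on the single local tensor and return the pieces needed to
    # rebuild the ShardedTensor: new local tensor, sharding spec, global size.
    return torch.tanh(st.local_tensor()), st.sharding_spec(), st.size()

_register_sharded_op_on_local_tensor(
    torch.tanh,
    early_stop_func=_my_early_stop,
    extra_check=_my_extra_check,
    customized_func=_my_compute,
)
```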

ghstack-source-id: 155309147

Test Plan: CI

Reviewed By: pritamdamania87

Differential Revision: D35123021

fbshipit-source-id: 660e559fb8b4a910eb63e0586c63ab927873a2ce
(cherry picked from commit 83a87ebf627d863448dfe1019c7c5f7112cc14ab)
Junjie Wang (PyTorch) 2022-05-03 10:09:21 -07:00 committed by PyTorch MergeBot
parent 47e7b12d39
commit 7c44d560ba
14 changed files with 1377 additions and 35 deletions

View File

@ -0,0 +1,89 @@
# Owner(s): ["oncall: distributed"]
import sys
import torch
from torch.distributed._shard import sharded_tensor, _shard_tensor
from torch.testing._internal.common_distributed import (
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
run_tests,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
TEST_GPU_NUM,
ShardedTensorTestBase,
with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
generate_chunk_sharding_specs_for_test,
generate_enumerable_sharding_specs_for_test,
)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
class TestShardedTensorChunkOps(ShardedTensorTestBase):
def _compare_chunk_result(self, chunked_list, chunked_st_list):
self.assertEqual(len(chunked_list), len(chunked_st_list))
for idx, chunked_st in enumerate(chunked_st_list):
tensor = chunked_list[idx]
st = _shard_tensor(tensor.contiguous(), chunked_st.sharding_spec())
# _shard_tensor generates a sharded tensor whose shard metadata is ordered by rank.
st._metadata.shards_metadata.sort(
key=lambda x: x.shard_offsets[chunked_st.sharding_spec().dim],
)
self.assertTrue(torch.allclose(chunked_st, st))
def _run_sharded_chunk_test(self, local_tensor_size, shard_spec, chunk_num):
torch.manual_seed(0)
local_tensor = torch.rand(*local_tensor_size).cuda(self.rank)
st_tensor = _shard_tensor(local_tensor.clone().detach(), shard_spec)
local_tensor_chunked = torch.chunk(local_tensor, chunk_num, dim=-1)
chunked_st = torch.chunk(st_tensor, chunk_num, dim=-1)
self._compare_chunk_result(local_tensor_chunked, chunked_st)
# TODO: Add sharded tensor chunk test cases back after making st a subclass of Tensor.
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_chunk(self):
sharding_dims = [0]
specs = []
for dim in sharding_dims:
specs.extend(generate_chunk_sharding_specs_for_test(dim))
for spec in specs:
self._run_sharded_chunk_test([17, 14], spec, 3)
self._run_sharded_chunk_test([17, 15, 20], spec, 5)
self._run_sharded_chunk_test([17, 16], spec, 2)
# Large matrix case.
self._run_sharded_chunk_test([128, 512], spec, 8)
self._run_sharded_chunk_test([1024, 2048], spec, 4)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_chunk_error(self):
chunk_spec = generate_chunk_sharding_specs_for_test(-1)
with self.assertRaisesRegex(
NotImplementedError, "Chunk by sharding dim is not supported."
):
st = sharded_tensor.rand(chunk_spec[0], [17, 24])
torch.chunk(st, 5, dim=-1)
enumerable_spec = generate_enumerable_sharding_specs_for_test()
with self.assertRaisesRegex(
NotImplementedError, "Only ChunkShardingSpec is supported for chunk."
):
st = sharded_tensor.rand(enumerable_spec[0], [10, 10])
torch.chunk(st, 5, dim=-1)
if __name__ == "__main__":
run_tests()

View File

@ -30,14 +30,18 @@ if TEST_WITH_DEV_DBG_ASAN:
class TestShardedTensorElementWiseOps(ShardedTensorTestBase):
- def _run_sharded_elementwise_ops(self, spec, input_size, op):
+ def _run_sharded_elementwise_ops(
+ self, spec, input_size, op, reset_seed=None, **kwargs
+ ):
torch.manual_seed(self.rank)
st = sharded_tensor.rand(spec, *input_size)
- new_st = op(st)
+ reset_seed() if reset_seed else None
+ new_st = op(st, **kwargs)
local_shard = st.local_tensor()
new_st_local_shard = new_st.local_tensor()
+ reset_seed() if reset_seed else None
self.assertEqual(
- op(local_shard),
+ op(local_shard, **kwargs),
new_st_local_shard,
)
@ -67,6 +71,37 @@ class TestShardedTensorElementWiseOps(ShardedTensorTestBase):
self._run_sharded_elementwise_ops(spec, [17, 23], torch.nn.functional.relu)
self._run_sharded_elementwise_ops(spec, [14, 15], torch.nn.functional.relu)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_dropout(self):
def _reset_random_seed():
torch.manual_seed(self.rank + 4)
specs = generate_chunk_sharding_specs_for_test(
0
) + generate_chunk_sharding_specs_for_test(1)
for spec in specs:
self._run_sharded_elementwise_ops(
spec,
[12, 17],
torch.nn.functional.dropout,
p=0.4,
reset_seed=_reset_random_seed,
)
self._run_sharded_elementwise_ops(
spec,
[18, 21],
torch.nn.functional.dropout,
p=0.5,
reset_seed=_reset_random_seed,
)
_reset_random_seed()
dropout = torch.nn.Dropout(p=0.8)
self._run_sharded_elementwise_ops(
spec, [17, 23], dropout, reset_seed=_reset_random_seed
)
if __name__ == "__main__":
run_tests()

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]
import torch
+ from torch.distributed._shard import _shard_tensor
import torch.distributed._shard.sharded_tensor as sharded_tensor
import torch.distributed as dist
@ -15,17 +16,19 @@ from torch.testing._internal.common_distributed import (
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
+ TEST_GPU_NUM,
ShardedTensorTestBase,
with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
- gen_binary_op_func
+ gen_binary_op_func,
+ generate_chunk_sharding_specs_for_test,
)
class TestMathOps(ShardedTensorTestBase):
@with_comms(init_rpc=False)
- @skip_if_lt_x_gpu(4)
+ @skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_basic_math_ops(self):
ops = ["torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*", "/"]
@ -83,7 +86,7 @@ class TestMathOps(ShardedTensorTestBase):
@with_comms(init_rpc=False)
- @skip_if_lt_x_gpu(4)
+ @skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_math_ops_errors(self):
spec = ChunkShardingSpec(
@ -128,3 +131,57 @@ class TestMathOps(ShardedTensorTestBase):
with self.assertRaisesRegex(TypeError, 'with ChunkShardingSpec supports'):
torch.add(st, sharded_rhs)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_bmm(self):
for spec in generate_chunk_sharding_specs_for_test(0):
lhs = torch.rand(15, 4, 5).cuda(self.rank)
rhs = torch.rand(15, 5, 6).cuda(self.rank)
tensor = lhs.bmm(rhs)
st_lhs = _shard_tensor(lhs, spec)
st_rhs = _shard_tensor(rhs, spec)
st_expected = _shard_tensor(tensor, spec)
st_expected._metadata.shards_metadata.sort(
key=lambda x: x.shard_offsets[0],
)
self.assertTrue(torch.allclose(torch.bmm(st_lhs, st_rhs), st_expected))
# TODO: Add back test bases after we make ShardedTensor subclass of Tensor.
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_bmm_errors(self):
specs = generate_chunk_sharding_specs_for_test(0)
st_lhs = sharded_tensor.rand(specs[0], (15, 5, 6))
st_rhs = sharded_tensor.rand(specs[1], (15, 5, 6))
with self.assertRaisesRegex(
NotImplementedError,
'Both st and st2 need to have same placements for bmm',
):
torch.bmm(st_lhs, st_rhs)
for spec in specs:
st_lhs = sharded_tensor.rand(spec, (20, 3))
st_rhs = sharded_tensor.rand(spec, (20, 3))
with self.assertRaisesRegex(
TypeError,
'both st and st2 need to be a 3D ShardedTensor',
):
torch.bmm(st_lhs, st_rhs)
rhs = torch.rand(15, 5, 6).cuda(self.rank)
with self.assertRaisesRegex(
TypeError,
'st2 needs to be a ShardedTensor for torch.bmm',
):
torch.bmm(st_lhs, rhs)
spec.dim = 1
st_lhs = sharded_tensor.rand(spec, (15, 5, 6))
st_rhs = sharded_tensor.rand(spec, (15, 5, 6))
with self.assertRaisesRegex(
NotImplementedError,
'Only support performing bmm on tensors sharded on dim 0 now',
):
torch.bmm(st_lhs, st_rhs)

View File

@ -0,0 +1,263 @@
# Owner(s): ["oncall: distributed"]
import copy
import sys
import torch
from torch.distributed._shard import sharded_tensor, _shard_tensor
from torch.testing._internal.common_distributed import (
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
run_tests,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
TEST_GPU_NUM,
ShardedTensorTestBase,
with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
generate_enumerable_sharding_specs_for_test,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
_chunk_sharding_specs_list_for_test,
)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
class TestShardedTensorMatrixOps(ShardedTensorTestBase):
@with_comms(init_rpc=True)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_tensor_contiguous(self):
specs = _chunk_sharding_specs_list_for_test([0], seed=7)
for spec in specs:
st = sharded_tensor.rand(spec, 10, 22, 5, init_rrefs=True)
st = st.transpose(1, 0)
st = st.contiguous()
self.assertTrue(st.is_contiguous())
self.assertTrue(st.local_tensor().is_contiguous())
@with_comms(init_rpc=True)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_tensor_type_as(self):
specs = _chunk_sharding_specs_list_for_test([0], seed=7)
for spec in specs:
st = sharded_tensor.rand(
spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
)
st_2 = sharded_tensor.rand(
spec, 16, 30, 5, init_rrefs=True, dtype=torch.float
)
st_3 = st.type_as(st_2)
self.assertEqual(torch.float, st_3.dtype)
self.assertEqual(torch.float, st_3.local_tensor().dtype)
st_3 = st.type_as(torch.zeros(10).type(torch.BoolTensor).cuda())
self.assertEqual(torch.bool, st_3.dtype)
self.assertEqual(torch.bool, st_3.local_tensor().dtype)
@with_comms(init_rpc=True)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_tensor_transpose(self):
specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
for spec in specs:
tensor = torch.rand(15, 27, 16).cuda(self.rank)
tensor_t = tensor.transpose(0, 1).contiguous()
spec_n = copy.deepcopy(spec)
if spec_n.dim in (0, 1):
spec_n.dim = 1 - spec_n.dim
st_expected = _shard_tensor(tensor_t, spec_n)
st_expected._metadata.shards_metadata.sort(
key=lambda x: x.shard_offsets[0],
)
self.assertTrue(
torch.allclose(
torch.transpose(_shard_tensor(tensor, spec), 0, 1), st_expected
)
)
tensor_t = torch.transpose(tensor, 1, 2).contiguous()
spec_n = copy.deepcopy(spec)
if spec_n.dim in (1, 2):
spec_n.dim = 3 - spec_n.dim
st_expected = _shard_tensor(tensor_t, spec_n)
st_expected._metadata.shards_metadata.sort(
key=lambda x: x.shard_offsets[0],
)
self.assertTrue(
torch.allclose(_shard_tensor(tensor, spec).transpose(1, 2), st_expected)
)
@with_comms(init_rpc=True)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_tensor_transpose_error(self):
enumerable_spec = generate_enumerable_sharding_specs_for_test()[0]
st = sharded_tensor.rand(
enumerable_spec, 10, 10, init_rrefs=True, dtype=torch.double
)
with self.assertRaisesRegex(
NotImplementedError,
"Only ChunkShardingSpec supported for 'transpose'",
):
st.transpose(1, 0)
@with_comms(init_rpc=True)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_tensor_softmax(self):
specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
for spec in specs:
tensor = torch.rand(15, 27, 16).cuda(self.rank)
tensor_n = torch.nn.functional.softmax(tensor, dim=1, dtype=torch.float32)
st_expected = _shard_tensor(tensor_n, spec)
st_expected._metadata.shards_metadata.sort(
key=lambda x: x.shard_offsets[0],
)
self.assertTrue(
torch.allclose(
torch.nn.functional.softmax(
_shard_tensor(tensor, spec), dim=1, dtype=torch.float32
),
st_expected,
)
)
@with_comms(init_rpc=True)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_tensor_softmax_error(self):
specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
for spec in specs:
st = sharded_tensor.rand(
spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
)
with self.assertRaisesRegex(
NotImplementedError,
"Only support performing softmax on non-sharding dim now.",
):
torch.nn.functional.softmax(
st, dim=st.sharding_spec().dim, dtype=torch.float32
)
def _test_masked_fill_with_sizes(self, mask_size, broadcast_style=False):
specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
for spec in specs:
tensor = torch.rand(35, 17, 26).cuda(self.rank)
mask = torch.randint(0, 2, mask_size).type(torch.BoolTensor).cuda(self.rank)
if broadcast_style:
mask = mask.unsqueeze(1)
tensor_m = tensor.masked_fill(mask, 25.0)
st_expected = _shard_tensor(tensor_m, spec)
st_expected._metadata.shards_metadata.sort(
key=lambda x: x.shard_offsets[0],
)
self.assertTrue(
torch.allclose(
_shard_tensor(tensor, spec).masked_fill(mask, 25.0),
st_expected,
)
)
@with_comms(init_rpc=True)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_tensor_masked_fill(self):
self._test_masked_fill_with_sizes((35, 17, 26))
self._test_masked_fill_with_sizes((17, 26))
self._test_masked_fill_with_sizes((35, 26), broadcast_style=True)
self._test_masked_fill_with_sizes((26,))
@with_comms(init_rpc=True)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_tensor_masked_fill_error(self):
specs = _chunk_sharding_specs_list_for_test([1, 2], seed=7)
for spec in specs:
st = sharded_tensor.rand(
spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
)
mask = (
torch.randint(0, 2, (2, 35, 17, 26))
.type(torch.BoolTensor)
.cuda(self.rank)
)
with self.assertRaisesRegex(
ValueError,
"mask dim must not greater than the dim of the sharded tensor.",
):
st.masked_fill(mask, 25.0)
mask = torch.randint(0, 2, (16, 26)).type(torch.BoolTensor).cuda(self.rank)
with self.assertRaisesRegex(
ValueError,
"The size of mask 0 must match the size of sharded tensor 1 "
"at non-singleton dimension 0",
):
st.masked_fill(mask, 25.0)
@with_comms(init_rpc=True)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_tensor_view(self):
specs = _chunk_sharding_specs_list_for_test([0, 0], seed=10)
for spec in specs:
tensor = torch.rand(16, 35, 26).cuda(self.rank)
tensor_v = tensor.view(16, 35, 26).view(4, 4, 35, 26)
st_expected = _shard_tensor(tensor_v, spec)
st_expected._metadata.shards_metadata.sort(
key=lambda x: x.shard_offsets[0],
)
self.assertTrue(
torch.allclose(
_shard_tensor(tensor, spec).view(4, 4, 35, 26),
st_expected,
)
)
st_expected = _shard_tensor(tensor, spec)
st_expected._metadata.shards_metadata.sort(
key=lambda x: x.shard_offsets[0],
)
self.assertTrue(
torch.allclose(
_shard_tensor(tensor_v, spec).view(16, 35, 26),
st_expected,
)
)
@with_comms(init_rpc=True)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_sharded_tensor_view_error(self):
for spec in _chunk_sharding_specs_list_for_test([2], seed=7):
st = sharded_tensor.rand(
spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
)
with self.assertRaisesRegex(
NotImplementedError,
"Shape having dim 2 is not supported "
"for sharded tensor sharded on dim 2.",
):
st.view(35 * 17, 26)
with self.assertRaisesRegex(
ValueError,
r"Shape '\[5, 7, 35, 17, 26\]' is invalid for sharded tensor size 15470.",
):
st.view(5, 7, 35, 17, 26)
with self.assertRaisesRegex(
ValueError,
"Only one dimension can be inferred for sharded view op.",
):
st.view(5, 7, -1, -1)
if __name__ == "__main__":
run_tests()

View File

@ -206,6 +206,7 @@ WINDOWS_BLOCKLIST = [
"distributed/_shard/sharded_tensor/test_megatron_prototype", "distributed/_shard/sharded_tensor/test_megatron_prototype",
"distributed/_shard/sharded_tensor/test_sharded_tensor", "distributed/_shard/sharded_tensor/test_sharded_tensor",
"distributed/_shard/sharded_tensor/test_sharded_tensor_reshard", "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
"distributed/_shard/sharded_tensor/ops/test_chunk",
"distributed/_shard/sharded_tensor/ops/test_elementwise_ops", "distributed/_shard/sharded_tensor/ops/test_elementwise_ops",
"distributed/_shard/sharded_tensor/ops/test_embedding", "distributed/_shard/sharded_tensor/ops/test_embedding",
"distributed/_shard/sharded_tensor/ops/test_embedding_bag", "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
@ -213,6 +214,7 @@ WINDOWS_BLOCKLIST = [
"distributed/_shard/sharded_tensor/ops/test_init", "distributed/_shard/sharded_tensor/ops/test_init",
"distributed/_shard/sharded_tensor/ops/test_linear", "distributed/_shard/sharded_tensor/ops/test_linear",
"distributed/_shard/sharded_tensor/ops/test_math_ops", "distributed/_shard/sharded_tensor/ops/test_math_ops",
"distributed/_shard/sharded_tensor/ops/test_matrix_ops",
"distributed/_shard/sharding_spec/test_sharding_spec", "distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim", "distributed/_shard/sharded_optim/test_sharded_optim",
"distributed/_shard/test_partial_tensor", "distributed/_shard/test_partial_tensor",
@ -229,6 +231,7 @@ ROCM_BLOCKLIST = [
"distributed/_shard/sharded_tensor/test_megatron_prototype", "distributed/_shard/sharded_tensor/test_megatron_prototype",
"distributed/_shard/sharded_tensor/test_sharded_tensor", "distributed/_shard/sharded_tensor/test_sharded_tensor",
"distributed/_shard/sharded_tensor/test_sharded_tensor_reshard", "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
"distributed/_shard/sharded_tensor/ops/test_chunk",
"distributed/_shard/sharded_tensor/ops/test_elementwise_ops", "distributed/_shard/sharded_tensor/ops/test_elementwise_ops",
"distributed/_shard/sharded_tensor/ops/test_embedding", "distributed/_shard/sharded_tensor/ops/test_embedding",
"distributed/_shard/sharded_tensor/ops/test_embedding_bag", "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
@ -236,6 +239,7 @@ ROCM_BLOCKLIST = [
"distributed/_shard/sharded_tensor/ops/test_init", "distributed/_shard/sharded_tensor/ops/test_init",
"distributed/_shard/sharded_tensor/ops/test_linear", "distributed/_shard/sharded_tensor/ops/test_linear",
"distributed/_shard/sharded_tensor/ops/test_math_ops", "distributed/_shard/sharded_tensor/ops/test_math_ops",
"distributed/_shard/sharded_tensor/ops/test_matrix_ops",
"distributed/_shard/sharding_spec/test_sharding_spec", "distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim", "distributed/_shard/sharded_optim/test_sharded_optim",
"distributed/_shard/test_partial_tensor", "distributed/_shard/test_partial_tensor",

View File

@ -1,5 +1,7 @@
+ import torch.distributed._shard.sharded_tensor._ops.chunk
import torch.distributed._shard.sharded_tensor._ops.elementwise_ops
import torch.distributed._shard.sharded_tensor._ops.math_ops
+ import torch.distributed._shard.sharded_tensor._ops.matrix_ops
from .binary_cmp import equal, allclose
from .embedding import sharded_embedding

View File

@ -1,9 +1,16 @@
# coding=utf-8
+ import functools
from typing import List
import torch
import torch.distributed as dist
+ import torch.distributed._shard.sharding_spec as shard_spec
+ from torch.distributed._shard.sharded_tensor import (
+ sharded_op_impl,
+ Shard,
+ ShardedTensor,
+ )
from torch.distributed._shard.sharding_spec._internals import (
get_split_size,
get_chunked_dim_size,
@ -14,6 +21,171 @@ from torch.distributed.nn.functional import (
)
def _chunk_sharding_spec_check(spec, op):
"""
For the given op implementation check if the sharding spec is ChunkShardingSpec.
"""
if not isinstance(spec, shard_spec.ChunkShardingSpec):
raise NotImplementedError(
f"Only ChunkShardingSpec supported for '{op.__name__}'."
)
def _sharded_op_common(op, early_stop_func, extra_check):
"""
Inject common logic for sharded tensor op registration; this logic is
executed before the op-specific behavior runs on either local shards or a local tensor.
Example::
>>> op = torch.transpose
>>> @sharded_op_impl(op)
>>> @_sharded_op_common(op, early_stop_func, extra_check)
>>> def sharded_tensor_op(types, args, kwargs, process_group):
>>> ....
>>>
>>> st = sharded_tensor.rand(32, 16)
>>> st.transpose(1, 2)
>>> # This will call '_sharded_op_common'
Args:
op: The op to be registered and applied to all shards of the st.
early_stop_func (Callable, optional): the func for early stop.
Default: if ``None``, no early stop.
extra_check (Callable, optional): the func for extra condition check.
Default: if ``None``, no extra check.
Return:
func (Callable): Torch function for which we want to provide a sharded
implementation (ex: torch.transpose)
"""
def decorator_sharded_func(wrapped_func):
@functools.wraps(wrapped_func)
def wrapper(types, args=(), kwargs=None, pg=None):
if len(args) == 0:
raise ValueError(f" No input for '{op.__name__}'!")
# Validate types
st = args[0]
if not isinstance(st, ShardedTensor):
raise TypeError(
f"torch function '{op.__name__}', with args: {args} and "
f"kwargs: {kwargs} are called for non ShardedTensor!"
)
if kwargs is None:
kwargs = {}
if extra_check:
extra_check(*args, **kwargs)
if early_stop_func:
early_stop = early_stop_func(*args, **kwargs)
if early_stop:
return st
return wrapped_func(types, args, kwargs, pg)
return wrapper
return decorator_sharded_func
def _register_sharded_op_on_local_shards(
op, early_stop_func=None, extra_check=None, customized_func=None
):
"""
Handles ``__torch_function__`` dispatch for ops which are performed on
each shard of the sharded tensor, such as elementwise ops like
``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
For more complicated ops, a customized func can be used to generate
the new shards and sharded tensor size.
Args:
op: The op to be registered and applied to all shards of the st.
early_stop_func (Callable, optional): the func for early stop.
Default: if ``None``, no early stop.
extra_check (Callable, optional): the func for extra condition check.
Default: if ``None``, no extra check.
customized_func (Callable, optional): the func for customized logic
to generate new shards and sharded tensor size.
Default: if ``None``, we simply lower to the real op call with
all local shards of the st.
Return:
func (Callable): registered implementation for sharded op for
``__torch_function__`` dispatch.
"""
@sharded_op_impl(op)
@_sharded_op_common(op, early_stop_func, extra_check)
def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
st = args[0]
st_metadata = st.metadata()
local_shards = st.local_shards()
local_shards_new = []
if customized_func:
local_shards_new, st_metadata = customized_func(args, kwargs, pg)
else:
for local_shard in local_shards:
args = (local_shard.tensor, *args[1:])
local_shards_new.append(
Shard(op(*args, **kwargs), local_shard.metadata)
)
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards_new,
st_metadata,
process_group=pg,
init_rrefs=st._init_rrefs,
)
def _register_sharded_op_on_local_tensor(
op, early_stop_func=None, extra_check=None, customized_func=None
):
"""
Handles ``__torch_function__`` dispatch for ops which are performed on
the single local tensor of the sharded tensor, such as ops like
``torch.nn.functional.softmax`` or ``torch.Tensor.view``.
For more complicated ops, a customized func can be used to generate
the new local tensor, sharding spec and sharded tensor size.
Args:
op: The op to be registered and applied to all shards of the st.
early_stop_func (Callable, optional): the func for early stop.
Default: if ``None``, no early stop.
extra_check (Callable, optional): the func for extra condition check.
Default: if ``None``, no extra check.
customized_func (Callable, optional): the func for customized logic
to generate the new local tensor, sharding spec and sharded tensor size.
Default: if ``None``, we simply lower to the real op call with
the single local tensor of the st.
Return:
func (Callable): registered implementation for sharded op for
``__torch_function__`` dispatch.
"""
@sharded_op_impl(op)
@_sharded_op_common(op, early_stop_func, extra_check)
def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
st = args[0]
sharding_spec = st.sharding_spec()
_chunk_sharding_spec_check(sharding_spec, op)
if len(st.local_shards()) != 1:
raise TypeError(
f"torch function '{op.__name__}', with args: {args} and "
f"kwargs: {kwargs} only supported for single local tensor!"
)
st_size = st.size()
if customized_func:
local_tensor, sharding_spec, st_size = customized_func(args, kwargs, pg)
else:
args = (st.local_tensor(), *args[1:])
local_tensor = op(*args, **kwargs)
return ShardedTensor._init_from_local_tensor(
local_tensor.contiguous(),
sharding_spec,
st_size, # type: ignore[arg-type]
process_group=pg,
init_rrefs=st._init_rrefs,
)
def _handle_col_wise_sharding_base(
op_func,
col_dim,

View File

@ -0,0 +1,63 @@
import torch
from torch.distributed._shard.sharded_tensor import (
sharded_op_impl,
ShardedTensor,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
def register_chunk_op(op):
@sharded_op_impl(op)
def sharded_chunk(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for the chunk op.
If we chunk by a non-sharding dim, we just chunk the local tensor
directly and create a list of sharded tensors based on the chunks.
Warning: Chunking by the sharding dim is not supported.
Args: same as ``torch.chunk``.
Return:
List[ShardedTensor]: Chunk results as a list of ShardedTensor.
"""
st = args[0]
chunk_num = args[1]
dim = kwargs.get("dim")
dim = dim if dim else 0
# Validate types
if not isinstance(st, ShardedTensor):
raise TypeError(
f"torch function '{op.__name__}', with args: {args} and "
f"kwargs: {kwargs} are called for non ShardedTensor!"
)
spec = st.sharding_spec()
if not isinstance(spec, ChunkShardingSpec):
raise NotImplementedError("Only ChunkShardingSpec is supported for chunk.")
if spec.dim == dim or st.dim() + spec.dim == dim or st.dim() + dim == spec.dim: # type: ignore[operator]
raise NotImplementedError("Chunk by sharding dim is not supported.")
local_tensor = st.local_tensor()
st_size = st.size()
dim = dim if dim > 0 else st.dim() + dim
results = []
for chunk_tensor in local_tensor.chunk(chunk_num, dim=dim):
new_st_size = (*st_size[:dim], chunk_tensor.size(dim), *st_size[dim + 1 :]) # type: ignore[index]
results.append(
ShardedTensor._init_from_local_tensor(
chunk_tensor.contiguous(),
st.sharding_spec(),
new_st_size,
process_group=pg,
)
)
return results
chunk_ops = [
torch.chunk,
torch.Tensor.chunk,
]
for op in chunk_ops:
register_chunk_op(op)

View File

@ -0,0 +1,35 @@
import torch
from torch.distributed._shard.sharded_tensor import (
sharded_op_impl,
)
def register_default_op(op):
@sharded_op_impl(op)
def tensor_default_op(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for the default tensor ops that
behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or
``torch.Tensor.dtype``. We simply lower to the real op call with
DisableTorchFunction context like ``torch.Tensor.__torch_function__``
to avoid recursions.
"""
if kwargs is None:
kwargs = {}
with torch._C.DisableTorchFunction():
return op(*args, **kwargs)
# Tensor properties access
register_default_op(torch.Tensor.requires_grad.__get__) # type: ignore[attr-defined]
register_default_op(torch.Tensor.shape.__get__) # type: ignore[attr-defined]
register_default_op(torch.Tensor.dtype.__get__) # type: ignore[attr-defined]
register_default_op(torch.Tensor.layout.__get__) # type: ignore[attr-defined]
register_default_op(torch.Tensor.size)
register_default_op(torch.Tensor.dim)
register_default_op(torch.Tensor.ndim.__get__) # type: ignore[attr-defined]
register_default_op(torch.Tensor.is_contiguous)
register_default_op(torch.Tensor.contiguous)
# __reduce_ex__ to dispatch to get_state/set_state
register_default_op(torch.Tensor.__reduce_ex__)

View File

@ -1,30 +1,10 @@
import torch
- from torch.distributed._shard.sharded_tensor import (
- sharded_op_impl,
- Shard,
- ShardedTensor,
- )
+ from ._common import (
+ _register_sharded_op_on_local_shards,
+ )
- def register_elementwise_op(op):
- @sharded_op_impl(op)
- def elementwise_op(types, args=(), kwargs=None, pg=None):
- """
- Handles ``__torch_function__`` dispatch for the elementwise op such
- as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
- This method computes on either a normal tensor or a sharded tensor.
- """
- input = args[0]
- # Validate types
- if not isinstance(input, ShardedTensor):
- raise TypeError("input needs to be a ShardedTensor")
- local_shards_new = []
- for local_shard in input.local_shards():
- local_shards_new.append(Shard(op(local_shard.tensor), local_shard.metadata))
- return ShardedTensor._init_from_local_shards_and_global_metadata(
- local_shards_new, input.metadata(), process_group=pg
- )
- register_elementwise_op(torch.nn.functional.gelu)
- register_elementwise_op(torch.nn.functional.relu)
+ _register_sharded_op_on_local_shards(torch.nn.functional.gelu)
+ _register_sharded_op_on_local_shards(torch.nn.functional.relu)
+ _register_sharded_op_on_local_shards(torch.nn.functional.dropout)

View File

@ -9,6 +9,12 @@ from torch.distributed._shard.replicated_tensor import ReplicatedTensor
from torch.distributed._shard._utils import narrow_tensor
+ from ._common import (
+ _chunk_sharding_spec_check,
+ _register_sharded_op_on_local_tensor,
+ )
def register_math_op(op):
@sharded_op_impl(op)
def binary_math_op(types, args=(), kwargs=None, pg=None):
@ -120,3 +126,70 @@ binary_ops = [
for op in binary_ops:
register_math_op(op)
def sharded_bmm_check(*args, **kwargs):
"""
Perform extra checks for the sharded_bmm op; for example, st2 needs to
be a ShardedTensor and both tensors need to be sharded by dim 0.
Args: same as ``torch.bmm``.
Return: None
"""
if len(args) < 2:
raise TypeError("Needs two tensors to perform torch.bmm.")
st = args[0]
st2 = args[1]
# Validate types
if not isinstance(st2, ShardedTensor):
raise TypeError("st2 needs to be a ShardedTensor for torch.bmm.")
_chunk_sharding_spec_check(st2.sharding_spec(), torch.bmm)
if st.dim() != 3 or st2.dim() != 3:
raise TypeError("both st and st2 need to be a 3D ShardedTensor")
if (
st.sharding_spec().dim != st2.sharding_spec().dim # type: ignore[attr-defined]
or st.sharding_spec().dim != 0
):
raise NotImplementedError(
"Only support performing bmm on tensors sharded on dim 0 now."
)
if st.sharding_spec().placements != st2.sharding_spec().placements: # type: ignore[attr-defined]
raise NotImplementedError(
"Both st and st2 need to have same placements for bmm."
)
def sharded_bmm(args, kwargs, pg):
"""
Handles ``__torch_function__`` dispatch for the sharded_bmm op.
Warning: For now we only support the case when both tensors are sharded
by dim 0 so that no communication is needed.
Args: same as ``torch.bmm``.
Return:
local_tensor (Tensor): New local tensor to build the sharded tensor.
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
sharding spec of the new sharded tensor.
new_st_size (torch.Size): Size of the new sharded tensor.
"""
st = args[0]
st2 = args[1]
local_tensor = torch.bmm(st.local_tensor(), st2.local_tensor())
new_st_size = (*st.size()[:-1], st2.size(-1))
return local_tensor, st.sharding_spec(), new_st_size
_register_sharded_op_on_local_tensor(
torch.Tensor.bmm,
extra_check=sharded_bmm_check,
customized_func=sharded_bmm,
)
_register_sharded_op_on_local_tensor(
torch.bmm,
extra_check=sharded_bmm_check,
customized_func=sharded_bmm,
)

View File

@ -0,0 +1,324 @@
import copy
import math
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import (
Shard,
ShardedTensor,
)
from ._common import (
_chunk_sharding_spec_check,
_register_sharded_op_on_local_tensor,
_register_sharded_op_on_local_shards,
)
def sharded_type_as_check(*args, **kwargs):
"""
Perform extra checks for the sharded_type_as op such as the input needs to
be either a Tensor or ShardedTensor.
Args: same as ``torch.Tensor.type_as``.
Return: None
"""
if len(args) < 2:
raise ValueError("Needs to give a tensor to cast type as!")
if not isinstance(args[1], torch.Tensor) and not isinstance(args[1], ShardedTensor):
raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!")
def same_dtype(*args, **kwargs):
"""
When the dtype is the same, return the original ShardedTensor.
Args: same as ``torch.Tensor.type_as``.
Return (bool): Whether to return early or not.
"""
return args[0].dtype == args[1].dtype
def sharded_type_as(args, kwargs, pg):
"""
Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op.
Args: same as ``torch.Tensor.type_as``.
Return:
new_local_shards (List[Shard]): Local shards for the new sharded tensor.
st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
"""
st = args[0]
tensor = args[1]
if isinstance(tensor, ShardedTensor):
tensor = tensor.local_tensor()
new_local_shards = []
for shard in st.local_shards():
new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
st_meta = copy.deepcopy(st._metadata)
st_meta.tensor_properties.dtype = tensor.dtype
return new_local_shards, st_meta
_register_sharded_op_on_local_shards(
torch.Tensor.type_as,
early_stop_func=same_dtype,
extra_check=sharded_type_as_check,
customized_func=sharded_type_as,
)
def transpose_same_dim(*args, **kwargs):
"""
When the dim0 and dim1 of transpose are the same, return the original ShardedTensor.
Args: same as ``torch.Tensor.transpose``.
Return (bool): Whether to return early or not.
"""
return args[1] == args[2]
def sharded_transpose_check(*args, **kwargs):
"""
Perform extra checks for the sharded_transpose op, e.g., at least two
dimension arguments must be given and the sharding spec needs to be a ChunkShardingSpec.
Args: same as ``torch.Tensor.transpose``.
Return: None
"""
if len(args) < 3:
raise ValueError("Needs at least two dimensions for transpose op!")
_chunk_sharding_spec_check(args[0].sharding_spec(), torch.Tensor.transpose)
def sharded_transpose(args, kwargs, pg):
"""
Handles ``__torch_function__`` dispatch for the ``torch.Tensor.transpose`` op.
Returns a new sharded tensor with the given dimensions transposed.
During the transpose, we keep the original sharding dim if the sharding
dim is neither dim0 nor dim1. Otherwise, we swap the sharding
dim to the other dimension of the transpose.
Args: (same as ``torch.Tensor.transpose``.)
dim0 (Int): the first dimension to be transposed.
dim1 (Int): the second dimension to be transposed.
Returns:
local_tensor (Tensor): New local tensor to build the sharded tensor.
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
sharding spec of the new sharded tensor.
new_st_size (torch.Size): Size of the new sharded tensor.
"""
def _swap_meta_data(data, idx0, idx1):
"""
Swap the item at idx0 and idx1 in the data list.
"""
data[idx0], data[idx1] = data[idx1], data[idx0]
st = args[0]
dim0 = args[1]
dim1 = args[2]
sharding_spec = copy.deepcopy(st.sharding_spec())
if sharding_spec.dim == dim0:
sharding_spec.dim = dim1
elif sharding_spec.dim == dim1:
sharding_spec.dim = dim0
st_size = list(st.size())
_swap_meta_data(st_size, dim0, dim1)
local_tensor = st.local_tensor().transpose(dim0, dim1).contiguous()
return local_tensor, sharding_spec, tuple(st_size)
_register_sharded_op_on_local_tensor(
torch.transpose,
early_stop_func=transpose_same_dim,
extra_check=sharded_transpose_check,
customized_func=sharded_transpose,
)
_register_sharded_op_on_local_tensor(
torch.Tensor.transpose,
early_stop_func=transpose_same_dim,
extra_check=sharded_transpose_check,
customized_func=sharded_transpose,
)
def sharded_softmax_check(*args, **kwargs):
"""
Perform extra checks for the ``torch.Tensor.softmax`` op; for now we don't
support doing softmax on the sharding dim.
Args: same as ``torch.Tensor.softmax``.
Return: None
"""
st = args[0]
dim = kwargs.get("dim")
dim = dim if dim is not None else 1 # If no dim specified, softmax use 1 as dim.
if dim == st.sharding_spec().dim:
raise NotImplementedError(
"Only support performing softmax on non-sharding dim now."
)
_register_sharded_op_on_local_tensor(
torch.nn.functional.softmax,
extra_check=sharded_softmax_check,
)
def sharded_masked_fill_check(*args, **kwargs):
"""
Perform extra checks for the ``torch.Tensor.masked_fill`` op.
Ensure the mask size is broadcastable with the size of
the sharded tensor.
Args: same as ``torch.Tensor.masked_fill``.
Return: None
"""
st = args[0]
mask = args[1]
if st.dim() < mask.dim():
raise ValueError(
"mask dim must not greater than the dim of the sharded tensor."
)
for idx in range(-1, -mask.dim() - 1, -1):
if mask.size(idx) != st.size(idx) and mask.size(idx) != 1:
raise ValueError(
f"The size of mask {mask.dim() + idx} must match the size of "
f"sharded tensor {st.dim() + idx} at non-singleton dimension {mask.dim() + idx}"
)
def sharded_masked_fill(args, kwargs, pg):
"""
Handles ``__torch_function__`` dispatch for the ``torch.Tensor.masked_fill`` op.
We first narrow down the mask to the size of local tensor if the mask
contains the sharding dim and then apply the mask to the local tensor.
Args: same as ``torch.Tensor.masked_fill``.
Return:
local_tensor (Tensor): New local tensor to build the sharded tensor.
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
sharding spec of the new sharded tensor.
new_st_size (torch.Size): Size of the new sharded tensor.
"""
st = args[0]
mask = args[1]
value = args[2]
current_rank = dist.get_rank(pg) # type: ignore[attr-defined]
sharding_dim = st.sharding_spec().dim # type: ignore[attr-defined]
narrow_idx = None
for idx in range(-1, -mask.dim() - 1, -1):
if st.dim() + idx == sharding_dim and mask.size(idx) != 1:
narrow_idx = idx
if narrow_idx is not None:
rank_idx = None
for idx, placement in enumerate(st._sharding_spec.placements): # type: ignore[attr-defined]
if placement.rank() == current_rank: # type: ignore[index]
rank_idx = idx # type: ignore[attr-defined]
shard_metadata = st.metadata().shards_metadata[rank_idx] # type: ignore[index]
mask = mask.narrow(
narrow_idx,
shard_metadata.shard_offsets[sharding_dim],
shard_metadata.shard_sizes[sharding_dim],
)
local_tensor = st.local_tensor().masked_fill(mask, value)
return local_tensor, st.sharding_spec(), st.size()
_register_sharded_op_on_local_tensor(
torch.Tensor.masked_fill,
extra_check=sharded_masked_fill_check,
customized_func=sharded_masked_fill,
)
def sharded_view_check(*args, **kwargs):
"""
Perform extra checks for the ``torch.Tensor.view`` op.
Args: same as ``torch.Tensor.view``.
Return: None
"""
st = args[0]
shape = args[1:]
if len(shape) == 0:
raise ValueError("Missing *shape for sharded view op.")
if len(shape) <= st.sharding_spec().dim:
raise NotImplementedError(
f"Shape having dim {len(shape)} is not supported "
f"for sharded tensor sharded on dim {st.sharding_spec().dim}."
)
st_size = math.prod(st.size()) # type: ignore[attr-defined]
shape_size = math.prod(shape) # type: ignore[attr-defined]
neg_sum = sum(i for i in shape if i < 0)
if shape_size > st_size or st_size % shape_size:
raise ValueError(
f"Shape '{list(shape)}' is invalid for sharded tensor size {st_size}."
)
if neg_sum < -1:
raise ValueError("Only one dimension can be inferred for sharded view op.")
def sharded_view(args, kwargs, pg):
"""
Handles ``__torch_function__`` dispatch for the ``torch.Tensor.view`` op.
For now we always keep the sharding dim after view. For example, if
a sharded tensor has size [16, 5] and is sharded by dim 0, and we view
it as [4, 2, 2, 5], it will still be sharded by dim 0.
Args: same as ``torch.Tensor.view``.
Return:
local_tensor (Tensor): New local tensor to build the sharded tensor.
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
sharding spec of the new sharded tensor.
new_st_size (torch.Size): Size of the new sharded tensor.
"""
st = args[0]
shape = args[1:]
try:
infer_idx = shape.index(-1)
except ValueError:
infer_idx = None
# Infer the dim which is specified with -1.
if infer_idx is not None:
st_size = math.prod(st.size()) # type: ignore[attr-defined]
shape_size = -1 * math.prod(shape) # type: ignore[attr-defined]
shape = (*shape[:infer_idx], st_size // shape_size, *shape[infer_idx + 1 :])
if st.size() == shape:
return st.local_tensor(), st.sharding_spec(), shape
sharding_dim = st.sharding_spec().dim
world_size = dist.get_world_size(pg)
if shape[sharding_dim] % world_size:
raise NotImplementedError(
f"Case when dim '({shape[sharding_dim]})' is not divisible "
"by world_size is not supported."
)
new_local_tensor_size = (
*shape[:sharding_dim],
shape[sharding_dim] // world_size,
*shape[sharding_dim + 1 :],
)
new_local_tensor = st.local_tensor().view(*new_local_tensor_size)
return new_local_tensor, st.sharding_spec(), shape
_register_sharded_op_on_local_tensor(
torch.Tensor.view,
extra_check=sharded_view_check,
customized_func=sharded_view,
)

View File

@ -9,6 +9,7 @@ from typing import (
Union
)
import copy
+ import math
import weakref
import threading
@ -787,9 +788,10 @@ class ShardedTensor(object):
size = self._metadata.size
if dim is None:
return size
- if dim < 0 or dim >= len(size):
+ if dim < -len(size) or dim >= len(size):
raise ValueError(
- f"Argument ``dim`` must be within the range of tensor dimensions [0, {len(size)})"
+ "Argument ``dim`` must be within the range of tensor "
+ f"dimensions [-{len(size)}, {len(size)})"
)
return size[dim]
@ -807,6 +809,215 @@ class ShardedTensor(object):
""" """
return self._metadata.tensor_properties.memory_format == torch.contiguous_format return self._metadata.tensor_properties.memory_format == torch.contiguous_format
def dim(self) -> int:
"""
Returns an `int` which represents the number of dimensions of the tensor.
Returns:
An `int` representing the number of dimensions of the tensor.
"""
return len(self._metadata.size)
def contiguous(self) -> ShardedTensor:
"""
Returns a new sharded tensor whose local tensors have been made contiguous.
"""
if self.is_contiguous():
return self
local_shards = []
for shard in self.local_shards():
local_shards.append(
Shard(shard.tensor.contiguous(), shard.metadata)
)
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards,
self._metadata,
process_group=self._process_group,
init_rrefs=self._init_rrefs,
)
def masked_fill(self, mask, value) -> ShardedTensor:
"""
Returns a new sharded tensor in which each shard has its elements
filled with value where mask is True. The shape of mask must be broadcastable
with the shape of the underlying tensor.
Args:
mask (BoolTensor): the boolean mask.
value (float): the value to fill in with.
Returns:
A :class:`ShardedTensor` object whose shards have had masked_fill applied.
"""
if self.dim() < mask.dim():
raise ValueError(
"mask dim must not greater than the dim of the sharded tensor."
)
for idx in range(-1, -mask.dim() - 1, -1):
if mask.size(idx) != self.size(idx) and mask.size(idx) != 1:
raise ValueError(
f"The size of mask {mask.dim() + idx} must match the size of "
f"sharded tensor {self.dim() + idx} at non-singleton dimension {mask.dim() + idx}"
)
current_rank = dist.get_rank(self._process_group) # type: ignore[attr-defined]
sharding_dim = self.sharding_spec().dim # type: ignore[attr-defined]
narrow_idx = None
for idx in range(-1, -mask.dim() - 1, -1):
if self.dim() + idx == sharding_dim and mask.size(idx) != 1:
narrow_idx = idx
if narrow_idx is not None:
rank_idx = None
for idx, placement in enumerate(self.sharding_spec().placements): # type: ignore[attr-defined]
if placement.rank() == current_rank: # type: ignore[index]
rank_idx = idx # type: ignore[attr-defined]
shard_metadata = self.metadata().shards_metadata[rank_idx] # type: ignore[index]
mask = mask.narrow(
narrow_idx,
shard_metadata.shard_offsets[sharding_dim],
shard_metadata.shard_sizes[sharding_dim],
)
local_tensor = self.local_tensor().masked_fill(mask, value)
return ShardedTensor._init_from_local_tensor(
local_tensor,
self.sharding_spec(),
self.size(), # type: ignore[arg-type]
process_group=self._process_group,
)
def type_as(self, tensor) -> ShardedTensor:
"""
Returns a new sharded tensor with each shard cast
to the type of the given tensor.
Args:
tensor (Tensor): the tensor which has the desired type.
Returns:
A :class:`ShardedTensor` object whose shards have had type_as applied.
"""
if isinstance(tensor, ShardedTensor):
tensor = tensor.local_tensor()
if self.dtype == tensor.dtype:
return self
local_shards = []
for shard in self.local_shards():
local_shards.append(
Shard(shard.tensor.type_as(tensor), shard.metadata)
)
st_meta = copy.deepcopy(self._metadata)
st_meta.tensor_properties.dtype = tensor.dtype
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards,
st_meta,
process_group=self._process_group,
init_rrefs=self._init_rrefs,
)
def view(self, *shape) -> ShardedTensor:
"""
Returns a new sharded tensor with the same data as the
self tensor but with a different shape for its local tensor.
For now, we only support passing the view op through to the
local tensor.
Args:
shape (torch.Size or int...) the desired size.
Returns:
A :class:`ShardedTensor` object whose local tensors have had the
view applied.
"""
if len(shape) == 0:
raise ValueError("Missing *shape for sharded view op.")
if len(shape) <= self.sharding_spec().dim: # type: ignore[attr-defined]
raise NotImplementedError(
f"Shape having dim {len(shape)} is not supported " # type: ignore[attr-defined]
f"for sharded tensor sharded on dim {self.sharding_spec().dim}."
)
st_size = math.prod(self.size()) # type: ignore[attr-defined]
shape_size = math.prod(shape) # type: ignore[attr-defined]
neg_sum = sum(i for i in shape if i < 0)
if shape_size > st_size or st_size % shape_size:
raise ValueError(
f"Shape '{list(shape)}' is invalid for sharded tensor size {st_size}."
)
if neg_sum < -1:
raise ValueError("Only one dimension can be inferred for sharded view op.")
try:
infer_idx = shape.index(-1)
except ValueError:
infer_idx = None # type: ignore[assignment]
# Infer the dim which is specified with -1.
if infer_idx is not None:
st_size = math.prod(self.size()) # type: ignore[attr-defined]
shape_size = -1 * math.prod(shape) # type: ignore[attr-defined]
shape = (*shape[:infer_idx], st_size // shape_size, *shape[infer_idx + 1 :])
if self.size() == shape:
return self
sharding_dim = self.sharding_spec().dim # type: ignore[attr-defined]
world_size = dist.get_world_size(self._process_group)
if shape[sharding_dim] % world_size:
raise NotImplementedError(
f"Case when dim '({shape[sharding_dim]})' is not divisible "
"by world_size is not supported."
)
new_local_tensor_size = (
*shape[:sharding_dim],
shape[sharding_dim] // world_size,
*shape[sharding_dim + 1 :],
)
return ShardedTensor._init_from_local_tensor(
self.local_tensor().view(*new_local_tensor_size).contiguous(),
self.sharding_spec(),
shape,
process_group=self._process_group,
)
def transpose(self, dim0, dim1) -> ShardedTensor:
"""
Returns a new sharded tensor with the given dimensions transposed.
During the transpose, the sharding dim moves with the swap: e.g., if the
tensor is sharded by dim 0 and we call transpose(1, 0), the returned
tensor will be sharded by dim 1.
Args:
dim0 (int): the first dimension to be transposed.
dim1 (int): the second dimension to be transposed.
Returns:
A :class:`ShardedTensor` object with the specified dims transposed.
"""
def _swap_meta_data(data, idx0, idx1):
"""
Swap the item at idx0 and idx1 in the data list.
"""
data[idx0], data[idx1] = data[idx1], data[idx0]
if not isinstance(self.sharding_spec(), shard_spec.ChunkShardingSpec):
raise NotImplementedError(
"Only ChunkShardingSpec supported for 'transpose'."
)
if dim0 == dim1:
return self
sharding_spec = copy.deepcopy(self.sharding_spec())
if sharding_spec.dim == dim0: # type: ignore[attr-defined]
sharding_spec.dim = dim1 # type: ignore[attr-defined]
elif sharding_spec.dim == dim1: # type: ignore[attr-defined]
sharding_spec.dim = dim0 # type: ignore[attr-defined]
st_size = list(self.size()) # type: ignore[arg-type]
_swap_meta_data(st_size, dim0, dim1)
return ShardedTensor._init_from_local_tensor(
self.local_tensor().transpose(dim0, dim1).contiguous(),
sharding_spec,
tuple(st_size),
process_group=self._process_group,
)
@property
def shape(self):
return self._metadata.size

View File

@ -1,7 +1,10 @@
import builtins
import torch
from torch.distributed._shard.sharding_spec import (
ChunkShardingSpec,
+ EnumerableShardingSpec,
+ ShardMetadata,
)
from torch.distributed._shard.sharding_spec._internals import (
get_chunked_dim_size,
@ -43,6 +46,35 @@ def generate_chunk_sharding_specs_for_test(sharding_dim):
]
def generate_enumerable_sharding_specs_for_test():
return [
EnumerableShardingSpec(
[
ShardMetadata(
shard_offsets=[0, 0],
shard_sizes=[5, 5],
placement="rank:0/cuda:0",
),
ShardMetadata(
shard_offsets=[5, 0],
shard_sizes=[5, 5],
placement="rank:1/cuda:1",
),
ShardMetadata(
shard_offsets=[0, 5],
shard_sizes=[5, 5],
placement="rank:2/cuda:2",
),
ShardMetadata(
shard_offsets=[5, 5],
shard_sizes=[5, 5],
placement="rank:3/cuda:3",
),
]
)
]
def generate_local_weight_sharding_params_for_test(
local_weight, sharded_dim, gpu_num, spec, rank
):
@ -87,10 +119,12 @@ def clone_module_parameter(module, param_name):
tensor = getattr(module, param_name)
return torch.nn.Parameter(tensor.detach().clone())
- def gen_binary_op_func(python_op):
+ def gen_binary_op_func(python_op, inplace=False):
src_lines = ['def f(lhs, rhs):']
if "torch" in python_op:
src_lines.append(f' return {python_op}(lhs, rhs)\n')
+ elif inplace:
+ src_lines.append(f' lhs {python_op}= rhs\n return lhs\n')
else:
src_lines.append(f' return lhs {python_op} rhs\n')