[PT-D][Sharding] Enable ops needed in the transformer model training (#75374)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75374 While working with the FairSeq and MetaSeq codebases (which are essentially transformer models), we found that many ops are not yet supported by ShardedTensor. This change implements simple versions of those ops so that we can at least run a transformer example. The ops include: chunk, transpose, view, masked_fill, dropout, softmax, and type_as. The common logic for registering simple ops is factored into helper functions, so registering a new op in the future requires implementing at most three functions. ghstack-source-id: 155309147 Test Plan: CI Reviewed By: pritamdamania87 Differential Revision: D35123021 fbshipit-source-id: 660e559fb8b4a910eb63e0586c63ab927873a2ce (cherry picked from commit 83a87ebf627d863448dfe1019c7c5f7112cc14ab)
This commit is contained in:
parent 47e7b12d39
commit 7c44d560ba
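As a quick illustration of what this change enables (a sketch, not part of the diff): the newly registered ops can be called directly on a ShardedTensor, mirroring the tests added below. The snippet assumes an already-initialized NCCL process group with 4 ranks and one GPU per rank; the spec construction is an assumption modeled on the test helpers.

import torch
import torch.distributed as dist
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# Assumed: dist.init_process_group("nccl", ...) has already run on all 4 ranks.
rank = dist.get_rank()
spec = ChunkShardingSpec(
    dim=0,
    placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
)

st = sharded_tensor.rand(spec, 16, 32)          # global (16, 32), sharded on dim 0
pieces = torch.chunk(st, 4, dim=-1)             # chunk along a non-sharding dim
st_t = st.transpose(0, 1)                       # sharding dim moves with the transpose
probs = torch.nn.functional.softmax(st, dim=1)  # softmax along a non-sharding dim
drop = torch.nn.functional.dropout(st, p=0.1)   # applied shard by shard
st_v = st.view(4, 4, 32)                        # view keeps the sharding dim
st_b = st.type_as(torch.zeros(1, dtype=torch.bool).cuda(rank))  # per-shard dtype cast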
test/distributed/_shard/sharded_tensor/ops/test_chunk.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch.distributed._shard import sharded_tensor, _shard_tensor
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
    TEST_GPU_NUM,
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
    generate_chunk_sharding_specs_for_test,
    generate_enumerable_sharding_specs_for_test,
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestShardedTensorChunkOps(ShardedTensorTestBase):
    def _compare_chunk_result(self, chunked_list, chunked_st_list):
        self.assertEqual(len(chunked_list), len(chunked_st_list))
        for idx, chunked_st in enumerate(chunked_st_list):
            tensor = chunked_list[idx]
            st = _shard_tensor(tensor.contiguous(), chunked_st.sharding_spec())
            # _shard_tensor generates a sharded tensor whose shards_metadata is ordered by rank.
            st._metadata.shards_metadata.sort(
                key=lambda x: x.shard_offsets[chunked_st.sharding_spec().dim],
            )
            self.assertTrue(torch.allclose(chunked_st, st))

    def _run_sharded_chunk_test(self, local_tensor_size, shard_spec, chunk_num):
        torch.manual_seed(0)
        local_tensor = torch.rand(*local_tensor_size).cuda(self.rank)
        st_tensor = _shard_tensor(local_tensor.clone().detach(), shard_spec)
        local_tensor_chunked = torch.chunk(local_tensor, chunk_num, dim=-1)
        chunked_st = torch.chunk(st_tensor, chunk_num, dim=-1)
        self._compare_chunk_result(local_tensor_chunked, chunked_st)
        # TODO: Add sharded tensor chunk test cases back after making st a subclass of Tensor.

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_chunk(self):
        sharding_dims = [0]
        specs = []
        for dim in sharding_dims:
            specs.extend(generate_chunk_sharding_specs_for_test(dim))
        for spec in specs:
            self._run_sharded_chunk_test([17, 14], spec, 3)
            self._run_sharded_chunk_test([17, 15, 20], spec, 5)
            self._run_sharded_chunk_test([17, 16], spec, 2)
            # Large matrix case.
            self._run_sharded_chunk_test([128, 512], spec, 8)
            self._run_sharded_chunk_test([1024, 2048], spec, 4)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_chunk_error(self):
        chunk_spec = generate_chunk_sharding_specs_for_test(-1)
        with self.assertRaisesRegex(
            NotImplementedError, "Chunk by sharding dim is not supported."
        ):
            st = sharded_tensor.rand(chunk_spec[0], [17, 24])
            torch.chunk(st, 5, dim=-1)
        enumerable_spec = generate_enumerable_sharding_specs_for_test()
        with self.assertRaisesRegex(
            NotImplementedError, "Only ChunkShardingSpec is supported for chunk."
        ):
            st = sharded_tensor.rand(enumerable_spec[0], [10, 10])
            torch.chunk(st, 5, dim=-1)


if __name__ == "__main__":
    run_tests()
test/distributed/_shard/sharded_tensor/ops/test_elementwise_ops.py (changed)
@@ -30,14 +30,18 @@ if TEST_WITH_DEV_DBG_ASAN:
 
 
 class TestShardedTensorElementWiseOps(ShardedTensorTestBase):
-    def _run_sharded_elementwise_ops(self, spec, input_size, op):
+    def _run_sharded_elementwise_ops(
+        self, spec, input_size, op, reset_seed=None, **kwargs
+    ):
         torch.manual_seed(self.rank)
         st = sharded_tensor.rand(spec, *input_size)
-        new_st = op(st)
+        reset_seed() if reset_seed else None
+        new_st = op(st, **kwargs)
         local_shard = st.local_tensor()
         new_st_local_shard = new_st.local_tensor()
+        reset_seed() if reset_seed else None
         self.assertEqual(
-            op(local_shard),
+            op(local_shard, **kwargs),
             new_st_local_shard,
         )
 
@@ -67,6 +71,37 @@ class TestShardedTensorElementWiseOps(ShardedTensorTestBase):
         self._run_sharded_elementwise_ops(spec, [17, 23], torch.nn.functional.relu)
         self._run_sharded_elementwise_ops(spec, [14, 15], torch.nn.functional.relu)
 
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_sharded_dropout(self):
+        def _reset_random_seed():
+            torch.manual_seed(self.rank + 4)
+
+        specs = generate_chunk_sharding_specs_for_test(
+            0
+        ) + generate_chunk_sharding_specs_for_test(1)
+        for spec in specs:
+            self._run_sharded_elementwise_ops(
+                spec,
+                [12, 17],
+                torch.nn.functional.dropout,
+                p=0.4,
+                reset_seed=_reset_random_seed,
+            )
+            self._run_sharded_elementwise_ops(
+                spec,
+                [18, 21],
+                torch.nn.functional.dropout,
+                p=0.5,
+                reset_seed=_reset_random_seed,
+            )
+            _reset_random_seed()
+            dropout = torch.nn.Dropout(p=0.8)
+            self._run_sharded_elementwise_ops(
+                spec, [17, 23], dropout, reset_seed=_reset_random_seed
+            )
+
 
 if __name__ == "__main__":
     run_tests()
test/distributed/_shard/sharded_tensor/ops/test_math_ops.py (changed)
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 
 import torch
+from torch.distributed._shard import _shard_tensor
 import torch.distributed._shard.sharded_tensor as sharded_tensor
 import torch.distributed as dist
 
@@ -15,17 +16,19 @@ from torch.testing._internal.common_distributed import (
 )
 
 from torch.testing._internal.distributed._shard.sharded_tensor import (
     TEST_GPU_NUM,
     ShardedTensorTestBase,
     with_comms,
 )
 
 from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
-    gen_binary_op_func
+    gen_binary_op_func,
+    generate_chunk_sharding_specs_for_test,
 )
 
 class TestMathOps(ShardedTensorTestBase):
     @with_comms(init_rpc=False)
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
     def test_basic_math_ops(self):
         ops = ["torch.add", "torch.sub", "torch.mul", "torch.div", "+", "-", "*", "/"]
@@ -83,7 +86,7 @@ class TestMathOps(ShardedTensorTestBase):
 
 
     @with_comms(init_rpc=False)
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
     def test_math_ops_errors(self):
         spec = ChunkShardingSpec(
@@ -128,3 +131,57 @@ class TestMathOps(ShardedTensorTestBase):
 
         with self.assertRaisesRegex(TypeError, 'with ChunkShardingSpec supports'):
             torch.add(st, sharded_rhs)
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_sharded_bmm(self):
+        for spec in generate_chunk_sharding_specs_for_test(0):
+            lhs = torch.rand(15, 4, 5).cuda(self.rank)
+            rhs = torch.rand(15, 5, 6).cuda(self.rank)
+            tensor = lhs.bmm(rhs)
+            st_lhs = _shard_tensor(lhs, spec)
+            st_rhs = _shard_tensor(rhs, spec)
+            st_expected = _shard_tensor(tensor, spec)
+            st_expected._metadata.shards_metadata.sort(
+                key=lambda x: x.shard_offsets[0],
+            )
+            self.assertTrue(torch.allclose(torch.bmm(st_lhs, st_rhs), st_expected))
+            # TODO: Add back test bases after we make ShardedTensor subclass of Tensor.
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_sharded_bmm_errors(self):
+        specs = generate_chunk_sharding_specs_for_test(0)
+        st_lhs = sharded_tensor.rand(specs[0], (15, 5, 6))
+        st_rhs = sharded_tensor.rand(specs[1], (15, 5, 6))
+        with self.assertRaisesRegex(
+            NotImplementedError,
+            'Both st and st2 need to have same placements for bmm',
+        ):
+            torch.bmm(st_lhs, st_rhs)
+        for spec in specs:
+            st_lhs = sharded_tensor.rand(spec, (20, 3))
+            st_rhs = sharded_tensor.rand(spec, (20, 3))
+            with self.assertRaisesRegex(
+                TypeError,
+                'both st and st2 need to be a 3D ShardedTensor',
+            ):
+                torch.bmm(st_lhs, st_rhs)
+            rhs = torch.rand(15, 5, 6).cuda(self.rank)
+            with self.assertRaisesRegex(
+                TypeError,
+                'st2 needs to be a ShardedTensor for torch.bmm',
+            ):
+                torch.bmm(st_lhs, rhs)
+            spec.dim = 1
+            st_lhs = sharded_tensor.rand(spec, (15, 5, 6))
+            st_rhs = sharded_tensor.rand(spec, (15, 5, 6))
+            with self.assertRaisesRegex(
+                NotImplementedError,
+                'Only support performing bmm on tensors sharded on dim 0 now',
+            ):
+                torch.bmm(st_lhs, st_rhs)
test/distributed/_shard/sharded_tensor/ops/test_matrix_ops.py (new file, 263 lines)
@@ -0,0 +1,263 @@
# Owner(s): ["oncall: distributed"]

import copy
import sys

import torch
from torch.distributed._shard import sharded_tensor, _shard_tensor
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
    TEST_GPU_NUM,
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
    generate_enumerable_sharding_specs_for_test,
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
    _chunk_sharding_specs_list_for_test,
)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestShardedTensorMatrixOps(ShardedTensorTestBase):
    @with_comms(init_rpc=True)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_tensor_contiguous(self):
        specs = _chunk_sharding_specs_list_for_test([0], seed=7)
        for spec in specs:
            st = sharded_tensor.rand(spec, 10, 22, 5, init_rrefs=True)
            st = st.transpose(1, 0)
            st = st.contiguous()
            self.assertTrue(st.is_contiguous())
            self.assertTrue(st.local_tensor().is_contiguous())

    @with_comms(init_rpc=True)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_tensor_type_as(self):
        specs = _chunk_sharding_specs_list_for_test([0], seed=7)
        for spec in specs:
            st = sharded_tensor.rand(
                spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
            )
            st_2 = sharded_tensor.rand(
                spec, 16, 30, 5, init_rrefs=True, dtype=torch.float
            )
            st_3 = st.type_as(st_2)
            self.assertEqual(torch.float, st_3.dtype)
            self.assertEqual(torch.float, st_3.local_tensor().dtype)
            st_3 = st.type_as(torch.zeros(10).type(torch.BoolTensor).cuda())
            self.assertEqual(torch.bool, st_3.dtype)
            self.assertEqual(torch.bool, st_3.local_tensor().dtype)

    @with_comms(init_rpc=True)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_tensor_transpose(self):
        specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
        for spec in specs:
            tensor = torch.rand(15, 27, 16).cuda(self.rank)
            tensor_t = tensor.transpose(0, 1).contiguous()
            spec_n = copy.deepcopy(spec)
            if spec_n.dim in (0, 1):
                spec_n.dim = 1 - spec_n.dim
            st_expected = _shard_tensor(tensor_t, spec_n)
            st_expected._metadata.shards_metadata.sort(
                key=lambda x: x.shard_offsets[0],
            )
            self.assertTrue(
                torch.allclose(
                    torch.transpose(_shard_tensor(tensor, spec), 0, 1), st_expected
                )
            )
            tensor_t = torch.transpose(tensor, 1, 2).contiguous()
            spec_n = copy.deepcopy(spec)
            if spec_n.dim in (1, 2):
                spec_n.dim = 3 - spec_n.dim
            st_expected = _shard_tensor(tensor_t, spec_n)
            st_expected._metadata.shards_metadata.sort(
                key=lambda x: x.shard_offsets[0],
            )
            self.assertTrue(
                torch.allclose(_shard_tensor(tensor, spec).transpose(1, 2), st_expected)
            )

    @with_comms(init_rpc=True)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_tensor_transpose_error(self):
        enumerable_spec = generate_enumerable_sharding_specs_for_test()[0]
        st = sharded_tensor.rand(
            enumerable_spec, 10, 10, init_rrefs=True, dtype=torch.double
        )
        with self.assertRaisesRegex(
            NotImplementedError,
            "Only ChunkShardingSpec supported for 'transpose'",
        ):
            st.transpose(1, 0)

    @with_comms(init_rpc=True)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_tensor_softmax(self):
        specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
        for spec in specs:
            tensor = torch.rand(15, 27, 16).cuda(self.rank)
            tensor_n = torch.nn.functional.softmax(tensor, dim=1, dtype=torch.float32)
            st_expected = _shard_tensor(tensor_n, spec)
            st_expected._metadata.shards_metadata.sort(
                key=lambda x: x.shard_offsets[0],
            )
            self.assertTrue(
                torch.allclose(
                    torch.nn.functional.softmax(
                        _shard_tensor(tensor, spec), dim=1, dtype=torch.float32
                    ),
                    st_expected,
                )
            )

    @with_comms(init_rpc=True)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_tensor_softmax_error(self):
        specs = _chunk_sharding_specs_list_for_test([0, 2], seed=17)
        for spec in specs:
            st = sharded_tensor.rand(
                spec, 16, 30, 5, init_rrefs=True, dtype=torch.double
            )
            with self.assertRaisesRegex(
                NotImplementedError,
                "Only support performing softmax on non-sharding dim now.",
            ):
                torch.nn.functional.softmax(
                    st, dim=st.sharding_spec().dim, dtype=torch.float32
                )

    def _test_masked_fill_with_sizes(self, mask_size, broadcast_style=False):
        specs = _chunk_sharding_specs_list_for_test([0, 1, 2], seed=7)
        for spec in specs:
            tensor = torch.rand(35, 17, 26).cuda(self.rank)
            mask = torch.randint(0, 2, mask_size).type(torch.BoolTensor).cuda(self.rank)
            if broadcast_style:
                mask = mask.unsqueeze(1)
            tensor_m = tensor.masked_fill(mask, 25.0)
            st_expected = _shard_tensor(tensor_m, spec)
            st_expected._metadata.shards_metadata.sort(
                key=lambda x: x.shard_offsets[0],
            )
            self.assertTrue(
                torch.allclose(
                    _shard_tensor(tensor, spec).masked_fill(mask, 25.0),
                    st_expected,
                )
            )

    @with_comms(init_rpc=True)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_tensor_masked_fill(self):
        self._test_masked_fill_with_sizes((35, 17, 26))
        self._test_masked_fill_with_sizes((17, 26))
        self._test_masked_fill_with_sizes((35, 26), broadcast_style=True)
        self._test_masked_fill_with_sizes((26,))

    @with_comms(init_rpc=True)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_tensor_masked_fill_error(self):
        specs = _chunk_sharding_specs_list_for_test([1, 2], seed=7)
        for spec in specs:
            st = sharded_tensor.rand(
                spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
            )
            mask = (
                torch.randint(0, 2, (2, 35, 17, 26))
                .type(torch.BoolTensor)
                .cuda(self.rank)
            )
            with self.assertRaisesRegex(
                ValueError,
                "mask dim must not greater than the dim of the sharded tensor.",
            ):
                st.masked_fill(mask, 25.0)
            mask = torch.randint(0, 2, (16, 26)).type(torch.BoolTensor).cuda(self.rank)
            with self.assertRaisesRegex(
                ValueError,
                "The size of mask 0 must match the size of sharded tensor 1 "
                "at non-singleton dimension 0",
            ):
                st.masked_fill(mask, 25.0)

    @with_comms(init_rpc=True)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_tensor_view(self):
        specs = _chunk_sharding_specs_list_for_test([0, 0], seed=10)
        for spec in specs:
            tensor = torch.rand(16, 35, 26).cuda(self.rank)
            tensor_v = tensor.view(16, 35, 26).view(4, 4, 35, 26)
            st_expected = _shard_tensor(tensor_v, spec)
            st_expected._metadata.shards_metadata.sort(
                key=lambda x: x.shard_offsets[0],
            )
            self.assertTrue(
                torch.allclose(
                    _shard_tensor(tensor, spec).view(4, 4, 35, 26),
                    st_expected,
                )
            )
            st_expected = _shard_tensor(tensor, spec)
            st_expected._metadata.shards_metadata.sort(
                key=lambda x: x.shard_offsets[0],
            )
            self.assertTrue(
                torch.allclose(
                    _shard_tensor(tensor_v, spec).view(16, 35, 26),
                    st_expected,
                )
            )

    @with_comms(init_rpc=True)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_tensor_view_error(self):
        for spec in _chunk_sharding_specs_list_for_test([2], seed=7):
            st = sharded_tensor.rand(
                spec, 35, 17, 26, init_rrefs=True, dtype=torch.double
            )
            with self.assertRaisesRegex(
                NotImplementedError,
                "Shape having dim 2 is not supported "
                "for sharded tensor sharded on dim 2.",
            ):
                st.view(35 * 17, 26)
            with self.assertRaisesRegex(
                ValueError,
                r"Shape '\[5, 7, 35, 17, 26\]' is invalid for sharded tensor size 15470.",
            ):
                st.view(5, 7, 35, 17, 26)
            with self.assertRaisesRegex(
                ValueError,
                "Only one dimension can be inferred for sharded view op.",
            ):
                st.view(5, 7, -1, -1)


if __name__ == "__main__":
    run_tests()
test/run_test.py (changed)
@@ -206,6 +206,7 @@ WINDOWS_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/test_megatron_prototype",
     "distributed/_shard/sharded_tensor/test_sharded_tensor",
     "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
+    "distributed/_shard/sharded_tensor/ops/test_chunk",
     "distributed/_shard/sharded_tensor/ops/test_elementwise_ops",
     "distributed/_shard/sharded_tensor/ops/test_embedding",
     "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
@@ -213,6 +214,7 @@ WINDOWS_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/ops/test_init",
     "distributed/_shard/sharded_tensor/ops/test_linear",
     "distributed/_shard/sharded_tensor/ops/test_math_ops",
+    "distributed/_shard/sharded_tensor/ops/test_matrix_ops",
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
     "distributed/_shard/test_partial_tensor",
@@ -229,6 +231,7 @@ ROCM_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/test_megatron_prototype",
     "distributed/_shard/sharded_tensor/test_sharded_tensor",
     "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
+    "distributed/_shard/sharded_tensor/ops/test_chunk",
     "distributed/_shard/sharded_tensor/ops/test_elementwise_ops",
     "distributed/_shard/sharded_tensor/ops/test_embedding",
     "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
@@ -236,6 +239,7 @@ ROCM_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/ops/test_init",
     "distributed/_shard/sharded_tensor/ops/test_linear",
     "distributed/_shard/sharded_tensor/ops/test_math_ops",
+    "distributed/_shard/sharded_tensor/ops/test_matrix_ops",
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
     "distributed/_shard/test_partial_tensor",
torch/distributed/_shard/sharded_tensor/_ops/__init__.py (changed)
@@ -1,5 +1,7 @@
+import torch.distributed._shard.sharded_tensor._ops.chunk
 import torch.distributed._shard.sharded_tensor._ops.elementwise_ops
 import torch.distributed._shard.sharded_tensor._ops.math_ops
+import torch.distributed._shard.sharded_tensor._ops.matrix_ops
 
 from .binary_cmp import equal, allclose
 from .embedding import sharded_embedding
torch/distributed/_shard/sharded_tensor/_ops/_common.py (changed)
@@ -1,9 +1,16 @@
 # coding=utf-8
 
+import functools
 from typing import List
 
 import torch
 import torch.distributed as dist
+import torch.distributed._shard.sharding_spec as shard_spec
+from torch.distributed._shard.sharded_tensor import (
+    sharded_op_impl,
+    Shard,
+    ShardedTensor,
+)
 from torch.distributed._shard.sharding_spec._internals import (
     get_split_size,
     get_chunked_dim_size,
@@ -14,6 +21,171 @@ from torch.distributed.nn.functional import (
 )
 
 
+def _chunk_sharding_spec_check(spec, op):
+    """
+    For the given op implementation, check if the sharding spec is ChunkShardingSpec.
+    """
+    if not isinstance(spec, shard_spec.ChunkShardingSpec):
+        raise NotImplementedError(
+            f"Only ChunkShardingSpec supported for '{op.__name__}'."
+        )
+
+
+def _sharded_op_common(op, early_stop_func, extra_check):
+    """
+    Inject sharded tensor op registration with common logic executed before
+    different behaviors are done on either local shards or a local tensor.
+
+    Example::
+        >>> op = torch.transpose
+        >>> @sharded_op_impl(op)
+        >>> @_sharded_op_common(op, early_stop_func, extra_check)
+        >>> def sharded_tensor_op(types, args, kwargs, process_group):
+        >>>     ....
+        >>>
+        >>> st = sharded_tensor.rand(32, 16)
+        >>> st.transpose(1, 2)
+        >>> # This will call '_sharded_op_common'
+
+    Args:
+        op: The op to be registered and applied to all shards of the st.
+        early_stop_func (Callable, optional): the func for early stop.
+            Default: if ``None``, no early stop.
+        extra_check (Callable, optional): the func for extra condition check.
+            Default: if ``None``, no extra check.
+
+    Return:
+        func (Callable): Torch function for which we want to provide a sharded
+            implementation (ex: torch.transpose)
+    """
+    def decorator_sharded_func(wrapped_func):
+        @functools.wraps(wrapped_func)
+        def wrapper(types, args=(), kwargs=None, pg=None):
+            if len(args) == 0:
+                raise ValueError(f" No input for '{op.__name__}'!")
+            # Validate types
+            st = args[0]
+            if not isinstance(st, ShardedTensor):
+                raise TypeError(
+                    f"torch function '{op.__name__}', with args: {args} and "
+                    f"kwargs: {kwargs} are called for non ShardedTensor!"
+                )
+            if kwargs is None:
+                kwargs = {}
+            if extra_check:
+                extra_check(*args, **kwargs)
+            if early_stop_func:
+                early_stop = early_stop_func(*args, **kwargs)
+                if early_stop:
+                    return st
+            return wrapped_func(types, args, kwargs, pg)
+
+        return wrapper
+
+    return decorator_sharded_func
+
+
+def _register_sharded_op_on_local_shards(
+    op, early_stop_func=None, extra_check=None, customized_func=None
+):
+    """
+    Handles ``__torch_function__`` dispatch for ops which are performed on
+    each shard of the sharded tensor, such as elementwise ops like
+    ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
+
+    For more complicated ops, a customized func can be used to generate
+    the new shards and sharded tensor size.
+
+    Args:
+        op: The op to be registered and applied to all shards of the st.
+        early_stop_func (Callable, optional): the func for early stop.
+            Default: if ``None``, no early stop.
+        extra_check (Callable, optional): the func for extra condition check.
+            Default: if ``None``, no extra check.
+        customized_func (Callable, optional): the func for customized logic
+            to generate new shards and sharded tensor size.
+            Default: if ``None``, we simply lower to the real op call with
+                all local shards of the st.
+
+    Return:
+        func (Callable): registered implementation for sharded op for
+            ``__torch_function__`` dispatch.
+    """
+    @sharded_op_impl(op)
+    @_sharded_op_common(op, early_stop_func, extra_check)
+    def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
+        st = args[0]
+        st_metadata = st.metadata()
+        local_shards = st.local_shards()
+        local_shards_new = []
+        if customized_func:
+            local_shards_new, st_metadata = customized_func(args, kwargs, pg)
+        else:
+            for local_shard in local_shards:
+                args = (local_shard.tensor, *args[1:])
+                local_shards_new.append(
+                    Shard(op(*args, **kwargs), local_shard.metadata)
+                )
+        return ShardedTensor._init_from_local_shards_and_global_metadata(
+            local_shards_new,
+            st_metadata,
+            process_group=pg,
+            init_rrefs=st._init_rrefs,
+        )
+
+
+def _register_sharded_op_on_local_tensor(
+    op, early_stop_func=None, extra_check=None, customized_func=None
+):
+    """
+    Handles ``__torch_function__`` dispatch for ops which are performed on
+    the single local tensor of the sharded tensor, such as ops like
+    ``torch.nn.functional.softmax`` or ``torch.Tensor.view``.
+
+    For more complicated ops, a customized func can be used to generate
+    the new local tensor, sharding spec and sharded tensor size.
+
+    Args:
+        op: The op to be registered and applied to all shards of the st.
+        early_stop_func (Callable, optional): the func for early stop.
+            Default: if ``None``, no early stop.
+        extra_check (Callable, optional): the func for extra condition check.
+            Default: if ``None``, no extra check.
+        customized_func (Callable, optional): the func for customized logic
+            to generate the new local tensor, sharding spec and sharded tensor size.
+            Default: if ``None``, we simply lower to the real op call with
+                the single local tensor of the st.
+
+    Return:
+        func (Callable): registered implementation for sharded op for
+            ``__torch_function__`` dispatch.
+    """
+    @sharded_op_impl(op)
+    @_sharded_op_common(op, early_stop_func, extra_check)
+    def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
+        st = args[0]
+        sharding_spec = st.sharding_spec()
+        _chunk_sharding_spec_check(sharding_spec, op)
+        if len(st.local_shards()) != 1:
+            raise TypeError(
+                f"torch function '{op.__name__}', with args: {args} and "
+                f"kwargs: {kwargs} only supported for single local tensor!"
+            )
+        st_size = st.size()
+        if customized_func:
+            local_tensor, sharding_spec, st_size = customized_func(args, kwargs, pg)
+        else:
+            args = (st.local_tensor(), *args[1:])
+            local_tensor = op(*args, **kwargs)
+        return ShardedTensor._init_from_local_tensor(
+            local_tensor.contiguous(),
+            sharding_spec,
+            st_size,  # type: ignore[arg-type]
+            process_group=pg,
+            init_rrefs=st._init_rrefs,
+        )
+
+
 def _handle_col_wise_sharding_base(
     op_func,
     col_dim,
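The helpers above are what the commit summary means by "implement at most three functions for a new op": an optional extra_check, an optional early_stop_func, and an optional customized_func, wrapped by a single registration call. The sketch below is hypothetical and not part of this diff; it registers torch.nn.functional.log_softmax the same way softmax is registered in matrix_ops.py further down, relying on the default path that simply applies the op to the single local tensor. The absolute import path is an assumption based on the relative imports used elsewhere in this change.

import torch
# Assumed module path for the helpers added above:
from torch.distributed._shard.sharded_tensor._ops._common import (
    _register_sharded_op_on_local_tensor,
)


def sharded_log_softmax_check(*args, **kwargs):
    # Mirrors sharded_softmax_check below: reducing along the sharding dim
    # would require cross-rank communication, so reject it for now.
    st = args[0]
    dim = kwargs.get("dim")
    dim = dim if dim is not None else 1
    if dim == st.sharding_spec().dim:
        raise NotImplementedError(
            "Only support performing log_softmax on non-sharding dim now."
        )


# No customized_func: the default path applies log_softmax to the local tensor
# and rebuilds a ShardedTensor with the unchanged spec and global size.
_register_sharded_op_on_local_tensor(
    torch.nn.functional.log_softmax,
    extra_check=sharded_log_softmax_check,
)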
torch/distributed/_shard/sharded_tensor/_ops/chunk.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import torch
from torch.distributed._shard.sharded_tensor import (
    sharded_op_impl,
    ShardedTensor,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec


def register_chunk_op(op):
    @sharded_op_impl(op)
    def sharded_chunk(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for the chunk op.
        If we chunk by a non-sharding dim, we just directly chunk the
        local tensor and create a list of sharded tensors based on the results.

        Warnings: Chunk by the sharding dim is not supported.

        Args: same as ``torch.chunk``.

        Return:
            List[ShardedTensor]: Chunk results as a list of ShardedTensor.
        """
        st = args[0]
        chunk_num = args[1]
        dim = kwargs.get("dim")
        dim = dim if dim else 0

        # Validate types
        if not isinstance(st, ShardedTensor):
            raise TypeError(
                f"torch function '{op.__name__}', with args: {args} and "
                f"kwargs: {kwargs} are called for non ShardedTensor!"
            )
        spec = st.sharding_spec()
        if not isinstance(spec, ChunkShardingSpec):
            raise NotImplementedError("Only ChunkShardingSpec is supported for chunk.")
        if spec.dim == dim or st.dim() + spec.dim == dim or st.dim() + dim == spec.dim:  # type: ignore[operator]
            raise NotImplementedError("Chunk by sharding dim is not supported.")

        local_tensor = st.local_tensor()
        st_size = st.size()
        dim = dim if dim > 0 else st.dim() + dim
        results = []
        for chunk_tensor in local_tensor.chunk(chunk_num, dim=dim):
            new_st_size = (*st_size[:dim], chunk_tensor.size(dim), *st_size[dim + 1 :])  # type: ignore[index]
            results.append(
                ShardedTensor._init_from_local_tensor(
                    chunk_tensor.contiguous(),
                    st.sharding_spec(),
                    new_st_size,
                    process_group=pg,
                )
            )
        return results


chunk_ops = [
    torch.chunk,
    torch.Tensor.chunk,
]
for op in chunk_ops:
    register_chunk_op(op)
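A small worked example of the size bookkeeping above (a sketch; spec is assumed to be a dim-0 ChunkShardingSpec inside an initialized process group, as in test_chunk.py):

st = sharded_tensor.rand(spec, 17, 14)   # global size (17, 14), sharded on dim 0
pieces = torch.chunk(st, 3, dim=-1)      # chunk along the non-sharding dim 1
# torch.chunk splits size 14 into 5, 5, 4, so `pieces` holds three ShardedTensors
# with global sizes (17, 5), (17, 5), (17, 4), all keeping the same dim-0 spec.
# torch.chunk(st, 3, dim=0) would raise NotImplementedError (sharding dim).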
(new file, 35 lines)
@@ -0,0 +1,35 @@
import torch
from torch.distributed._shard.sharded_tensor import (
    sharded_op_impl,
)


def register_default_op(op):
    @sharded_op_impl(op)
    def tensor_default_op(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for the default tensor ops that
        behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or
        ``torch.Tensor.dtype``. We simply lower to the real op call with
        DisableTorchFunction context like ``torch.Tensor.__torch_function__``
        to avoid recursions.
        """
        if kwargs is None:
            kwargs = {}

        with torch._C.DisableTorchFunction():
            return op(*args, **kwargs)


# Tensor properties access
register_default_op(torch.Tensor.requires_grad.__get__)  # type: ignore[attr-defined]
register_default_op(torch.Tensor.shape.__get__)  # type: ignore[attr-defined]
register_default_op(torch.Tensor.dtype.__get__)  # type: ignore[attr-defined]
register_default_op(torch.Tensor.layout.__get__)  # type: ignore[attr-defined]
register_default_op(torch.Tensor.size)
register_default_op(torch.Tensor.dim)
register_default_op(torch.Tensor.ndim.__get__)  # type: ignore[attr-defined]
register_default_op(torch.Tensor.is_contiguous)
register_default_op(torch.Tensor.contiguous)

# __reduce_ex__ to dispatch to get_state/set_state
register_default_op(torch.Tensor.__reduce_ex__)
torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py (changed)
@@ -1,30 +1,10 @@
 import torch
-from torch.distributed._shard.sharded_tensor import (
-    sharded_op_impl,
-    Shard,
-    ShardedTensor,
+
+from ._common import (
+    _register_sharded_op_on_local_shards,
 )
 
 
-def register_elementwise_op(op):
-    @sharded_op_impl(op)
-    def elementwise_op(types, args=(), kwargs=None, pg=None):
-        """
-        Handles ``__torch_function__`` dispatch for the elementwise op such
-        as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
-        This method computes on either a normal tensor or a sharded tensor.
-        """
-        input = args[0]
-        # Validate types
-        if not isinstance(input, ShardedTensor):
-            raise TypeError("input needs to be a ShardedTensor")
-        local_shards_new = []
-        for local_shard in input.local_shards():
-            local_shards_new.append(Shard(op(local_shard.tensor), local_shard.metadata))
-        return ShardedTensor._init_from_local_shards_and_global_metadata(
-            local_shards_new, input.metadata(), process_group=pg
-        )
-
-
-register_elementwise_op(torch.nn.functional.gelu)
-register_elementwise_op(torch.nn.functional.relu)
+_register_sharded_op_on_local_shards(torch.nn.functional.gelu)
+_register_sharded_op_on_local_shards(torch.nn.functional.relu)
+_register_sharded_op_on_local_shards(torch.nn.functional.dropout)
torch/distributed/_shard/sharded_tensor/_ops/math_ops.py (changed)
@@ -9,6 +9,12 @@ from torch.distributed._shard.replicated_tensor import ReplicatedTensor
 
 from torch.distributed._shard._utils import narrow_tensor
 
+from ._common import (
+    _chunk_sharding_spec_check,
+    _register_sharded_op_on_local_tensor,
+)
+
+
 def register_math_op(op):
     @sharded_op_impl(op)
     def binary_math_op(types, args=(), kwargs=None, pg=None):
@@ -120,3 +126,70 @@
 
 for op in binary_ops:
     register_math_op(op)
+
+
+def sharded_bmm_check(*args, **kwargs):
+    """
+    Perform extra checks for the sharded_bmm op; for example, st2 needs to
+    be a sharded tensor and both tensors need to be sharded by dim 0, etc.
+
+    Args: same as ``torch.bmm``.
+
+    Return: None
+    """
+    if len(args) < 2:
+        raise TypeError("Needs two tensors to perform torch.bmm.")
+    st = args[0]
+    st2 = args[1]
+    # Validate types
+    if not isinstance(st2, ShardedTensor):
+        raise TypeError("st2 needs to be a ShardedTensor for torch.bmm.")
+    _chunk_sharding_spec_check(st2.sharding_spec(), torch.bmm)
+    if st.dim() != 3 or st2.dim() != 3:
+        raise TypeError("both st and st2 need to be a 3D ShardedTensor")
+    if (
+        st.sharding_spec().dim != st2.sharding_spec().dim  # type: ignore[attr-defined]
+        or st.sharding_spec().dim != 0
+    ):
+        raise NotImplementedError(
+            "Only support performing bmm on tensors sharded on dim 0 now."
+        )
+    if st.sharding_spec().placements != st2.sharding_spec().placements:  # type: ignore[attr-defined]
+        raise NotImplementedError(
+            "Both st and st2 need to have same placements for bmm."
+        )
+
+
+def sharded_bmm(args, kwargs, pg):
+    """
+    Handles ``__torch_function__`` dispatch for the sharded_bmm op.
+
+    Warning: For now we only support the case when both tensors are sharded
+    by dim 0 so that no communication is needed.
+
+    Args: same as ``torch.bmm``.
+
+    Return:
+        local_tensor (Tensor): New local tensor to build the sharded tensor.
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
+            sharding spec of the new sharded tensor.
+        new_st_size (torch.Size): Size of the new sharded tensor.
+    """
+    st = args[0]
+    st2 = args[1]
+    local_tensor = torch.bmm(st.local_tensor(), st2.local_tensor())
+    new_st_size = (*st.size()[:-1], st2.size(-1))
+    return local_tensor, st.sharding_spec(), new_st_size
+
+
+_register_sharded_op_on_local_tensor(
+    torch.Tensor.bmm,
+    extra_check=sharded_bmm_check,
+    customized_func=sharded_bmm,
+)
+
+_register_sharded_op_on_local_tensor(
+    torch.bmm,
+    extra_check=sharded_bmm_check,
+    customized_func=sharded_bmm,
+)
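Why no communication is needed for the bmm above: when both operands are chunk-sharded on dim 0 with identical placements, each rank holds the same contiguous slice of the batch dimension for lhs and rhs, so a bmm on the two local shards is exactly that rank's slice of the global result. A usage sketch (spec is an assumed dim-0 ChunkShardingSpec inside an initialized process group, as in test_sharded_bmm):

lhs = sharded_tensor.rand(spec, (15, 4, 5))   # batch dim 0 is sharded
rhs = sharded_tensor.rand(spec, (15, 5, 6))   # sharded with the same placements
out = torch.bmm(lhs, rhs)                     # ShardedTensor of global size (15, 4, 6)
# out.local_tensor() equals torch.bmm(lhs.local_tensor(), rhs.local_tensor()).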
torch/distributed/_shard/sharded_tensor/_ops/matrix_ops.py (new file, 324 lines)
@@ -0,0 +1,324 @@
import copy
import math

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
)

from ._common import (
    _chunk_sharding_spec_check,
    _register_sharded_op_on_local_tensor,
    _register_sharded_op_on_local_shards,
)


def sharded_type_as_check(*args, **kwargs):
    """
    Perform extra checks for the sharded_type_as op, such as that the input needs
    to be either a Tensor or ShardedTensor.

    Args: same as ``torch.Tensor.type_as``.

    Return: None
    """
    if len(args) < 2:
        raise ValueError("Needs to give a tensor to cast type as!")
    if not isinstance(args[1], torch.Tensor) and not isinstance(args[1], ShardedTensor):
        raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!")


def same_dtype(*args, **kwargs):
    """
    When the dtype is the same, return the original ShardedTensor.

    Args: same as ``torch.Tensor.type_as``.

    Return (bool): Whether to return early or not.
    """
    return args[0].dtype == args[1].dtype


def sharded_type_as(args, kwargs, pg):
    """
    Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op.

    Args: same as ``torch.Tensor.type_as``.

    Return:
        new_local_shards (List[Shard]): Local shards for the new sharded tensor.
        st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
    """
    st = args[0]
    tensor = args[1]
    if isinstance(tensor, ShardedTensor):
        tensor = tensor.local_tensor()
    new_local_shards = []
    for shard in st.local_shards():
        new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
    st_meta = copy.deepcopy(st._metadata)
    st_meta.tensor_properties.dtype = tensor.dtype
    return new_local_shards, st_meta


_register_sharded_op_on_local_shards(
    torch.Tensor.type_as,
    early_stop_func=same_dtype,
    extra_check=sharded_type_as_check,
    customized_func=sharded_type_as,
)


def transpose_same_dim(*args, **kwargs):
    """
    When dim0 and dim1 of the transpose are the same, return the original ShardedTensor.

    Args: same as ``torch.Tensor.transpose``.

    Return (bool): Whether to return early or not.
    """
    return args[1] == args[2]


def sharded_transpose_check(*args, **kwargs):
    """
    Perform extra checks for the sharded_transpose op, such as that two dimensions
    need to be given and the sharding spec needs to be a ChunkShardingSpec.

    Args: same as ``torch.Tensor.transpose``.

    Return: None
    """
    if len(args) < 3:
        raise ValueError("Needs at least two dimensions for transpose op!")
    _chunk_sharding_spec_check(args[0].sharding_spec(), torch.Tensor.transpose)


def sharded_transpose(args, kwargs, pg):
    """
    Handles ``__torch_function__`` dispatch for the ``torch.Tensor.transpose`` op.

    Returns a new sharded tensor with the given dimensions transposed.
    During the transpose, we keep the original sharding dim if the sharding
    dim is neither dim0 nor dim1. Otherwise, we swap the sharding
    dim with the other input of transpose.

    Args: (same as ``torch.Tensor.transpose``.)
        dim0 (Int): the first dimension to be transposed.
        dim1 (Int): the second dimension to be transposed.

    Returns:
        new_local_shards (List[Shard]): Local shards for the new sharded tensor.
        st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
    """

    def _swap_meta_data(data, idx0, idx1):
        """
        Swap the items at idx0 and idx1 in the data list.
        """
        data[idx0], data[idx1] = data[idx1], data[idx0]

    st = args[0]
    dim0 = args[1]
    dim1 = args[2]

    sharding_spec = copy.deepcopy(st.sharding_spec())
    if sharding_spec.dim == dim0:
        sharding_spec.dim = dim1
    elif sharding_spec.dim == dim1:
        sharding_spec.dim = dim0

    st_size = list(st.size())
    _swap_meta_data(st_size, dim0, dim1)
    local_tensor = st.local_tensor().transpose(dim0, dim1).contiguous()
    return local_tensor, sharding_spec, tuple(st_size)


_register_sharded_op_on_local_tensor(
    torch.transpose,
    early_stop_func=transpose_same_dim,
    extra_check=sharded_transpose_check,
    customized_func=sharded_transpose,
)
_register_sharded_op_on_local_tensor(
    torch.Tensor.transpose,
    early_stop_func=transpose_same_dim,
    extra_check=sharded_transpose_check,
    customized_func=sharded_transpose,
)


def sharded_softmax_check(*args, **kwargs):
    """
    Perform extra checks for the ``torch.Tensor.softmax`` op; for now we don't
    support doing softmax on the sharding dim.

    Args: same as ``torch.Tensor.softmax``.

    Return: None
    """
    st = args[0]
    dim = kwargs.get("dim")
    dim = dim if dim is not None else 1  # If no dim is specified, softmax uses 1 as the dim.
    if dim == st.sharding_spec().dim:
        raise NotImplementedError(
            "Only support performing softmax on non-sharding dim now."
        )


_register_sharded_op_on_local_tensor(
    torch.nn.functional.softmax,
    extra_check=sharded_softmax_check,
)


def sharded_masked_fill_check(*args, **kwargs):
    """
    Perform extra checks for the ``torch.Tensor.masked_fill`` op.
    Ensure the mask size is broadcastable with the size of
    the sharded tensor.

    Args: same as ``torch.Tensor.masked_fill``.

    Return: None
    """
    st = args[0]
    mask = args[1]
    if st.dim() < mask.dim():
        raise ValueError(
            "mask dim must not greater than the dim of the sharded tensor."
        )
    for idx in range(-1, -mask.dim() - 1, -1):
        if mask.size(idx) != st.size(idx) and mask.size(idx) != 1:
            raise ValueError(
                f"The size of mask {mask.dim() + idx} must match the size of "
                f"sharded tensor {st.dim() + idx} at non-singleton dimension {mask.dim() + idx}"
            )


def sharded_masked_fill(args, kwargs, pg):
    """
    Handles ``__torch_function__`` dispatch for the ``torch.Tensor.masked_fill`` op.
    We first narrow down the mask to the size of the local tensor if the mask
    covers the sharding dim and then apply the mask to the local tensor.

    Args: same as ``torch.Tensor.masked_fill``.

    Return:
        local_tensor (Tensor): New local tensor to build the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            sharding spec of the new sharded tensor.
        new_st_size (torch.Size): Size of the new sharded tensor.
    """
    st = args[0]
    mask = args[1]
    value = args[2]
    current_rank = dist.get_rank(pg)  # type: ignore[attr-defined]
    sharding_dim = st.sharding_spec().dim  # type: ignore[attr-defined]
    narrow_idx = None
    for idx in range(-1, -mask.dim() - 1, -1):
        if st.dim() + idx == sharding_dim and mask.size(idx) != 1:
            narrow_idx = idx
    if narrow_idx is not None:
        rank_idx = None
        for idx, placement in enumerate(st._sharding_spec.placements):  # type: ignore[attr-defined]
            if placement.rank() == current_rank:  # type: ignore[index]
                rank_idx = idx  # type: ignore[attr-defined]
        shard_metadata = st.metadata().shards_metadata[rank_idx]  # type: ignore[index]
        mask = mask.narrow(
            narrow_idx,
            shard_metadata.shard_offsets[sharding_dim],
            shard_metadata.shard_sizes[sharding_dim],
        )
    local_tensor = st.local_tensor().masked_fill(mask, value)
    return local_tensor, st.sharding_spec(), st.size()


_register_sharded_op_on_local_tensor(
    torch.Tensor.masked_fill,
    extra_check=sharded_masked_fill_check,
    customized_func=sharded_masked_fill,
)


def sharded_view_check(*args, **kwargs):
    """
    Perform extra checks for the ``torch.Tensor.view`` op.

    Args: same as ``torch.Tensor.view``.

    Return: None
    """
    st = args[0]
    shape = args[1:]
    if len(shape) == 0:
        raise ValueError("Missing *shape for sharded view op.")
    if len(shape) <= st.sharding_spec().dim:
        raise NotImplementedError(
            f"Shape having dim {len(shape)} is not supported "
            f"for sharded tensor sharded on dim {st.sharding_spec().dim}."
        )
    st_size = math.prod(st.size())  # type: ignore[attr-defined]
    shape_size = math.prod(shape)  # type: ignore[attr-defined]
    neg_sum = sum(i for i in shape if i < 0)
    if shape_size > st_size or st_size % shape_size:
        raise ValueError(
            f"Shape '{list(shape)}' is invalid for sharded tensor size {st_size}."
        )
    if neg_sum < -1:
        raise ValueError("Only one dimension can be inferred for sharded view op.")


def sharded_view(args, kwargs, pg):
    """
    Handles ``__torch_function__`` dispatch for the ``torch.Tensor.view`` op.
    For now we always keep the sharding dim after view. For example, if
    a sharded tensor with size [16, 5] is sharded by dim 0 and we now view
    it as [4, 2, 2, 5], it will still be sharded by dim 0.

    Args: same as ``torch.Tensor.view``.

    Return:
        local_tensor (Tensor): New local tensor to build the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            sharding spec of the new sharded tensor.
        new_st_size (torch.Size): Size of the new sharded tensor.
    """
    st = args[0]
    shape = args[1:]
    try:
        infer_idx = shape.index(-1)
    except ValueError:
        infer_idx = None

    # Infer the dim which is specified with -1.
    if infer_idx is not None:
        st_size = math.prod(st.size())  # type: ignore[attr-defined]
        shape_size = -1 * math.prod(shape)  # type: ignore[attr-defined]
        shape = (*shape[:infer_idx], st_size // shape_size, *shape[infer_idx + 1 :])
    if st.size() == shape:
        return st.local_tensor(), st.sharding_spec(), shape

    sharding_dim = st.sharding_spec().dim
    world_size = dist.get_world_size(pg)
    if shape[sharding_dim] % world_size:
        raise NotImplementedError(
            f"Case when dim '({shape[sharding_dim]})' is not divisible "
            "by world_size is not supported."
        )
    new_local_tensor_size = (
        *shape[:sharding_dim],
        shape[sharding_dim] // world_size,
        *shape[sharding_dim + 1 :],
    )
    new_local_tensor = st.local_tensor().view(*new_local_tensor_size)
    return new_local_tensor, st.sharding_spec(), shape


_register_sharded_op_on_local_tensor(
    torch.Tensor.view,
    extra_check=sharded_view_check,
    customized_func=sharded_view,
)
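To make the local-shard arithmetic in sharded_view concrete, here is a worked sketch matching test_sharded_tensor_view (assumptions: world_size == 4, a dim-0 ChunkShardingSpec spec, and rank is this process's rank):

tensor = torch.rand(16, 35, 26).cuda(rank)
st = _shard_tensor(tensor, spec)   # each rank holds a (4, 35, 26) local shard
st2 = st.view(4, 4, 35, 26)        # global size becomes (4, 4, 35, 26)
# The sharding dim stays 0, so each local shard is reshaped to
# (4 // world_size, 4, 35, 26) == (1, 4, 35, 26); a size on the sharding dim
# that is not divisible by world_size raises NotImplementedError.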
torch/distributed/_shard/sharded_tensor/api.py (changed)
@@ -9,6 +9,7 @@ from typing import (
     Union
 )
 import copy
+import math
 import weakref
 
 import threading
@@ -787,9 +788,10 @@ class ShardedTensor(object):
         size = self._metadata.size
         if dim is None:
             return size
-        if dim < 0 or dim >= len(size):
+        if dim < -len(size) or dim >= len(size):
             raise ValueError(
-                f"Argument ``dim`` must be within the range of tensor dimensions [0, {len(size)})"
+                "Argument ``dim`` must be within the range of tensor "
+                f"dimensions [-{len(size)}, {len(size)})"
             )
         return size[dim]
 
@@ -807,6 +809,215 @@ class ShardedTensor(object):
         """
         return self._metadata.tensor_properties.memory_format == torch.contiguous_format
 
+    def dim(self) -> int:
+        """
+        Returns an `int` which represents the dimension of the tensor.
+
+        Returns:
+            An `int` representing the dimension of the tensor.
+        """
+        return len(self._metadata.size)
+
+    def contiguous(self) -> ShardedTensor:
+        """
+        Returns a new sharded tensor whose local tensors have been made contiguous.
+        """
+        if self.is_contiguous():
+            return self
+        local_shards = []
+        for shard in self.local_shards():
+            local_shards.append(
+                Shard(shard.tensor.contiguous(), shard.metadata)
+            )
+        return ShardedTensor._init_from_local_shards_and_global_metadata(
+            local_shards,
+            self._metadata,
+            process_group=self._process_group,
+            init_rrefs=self._init_rrefs,
+        )
+
+    def masked_fill(self, mask, value) -> ShardedTensor:
+        """
+        Returns a new sharded tensor in which each shard has its elements
+        filled with value where mask is True. The shape of mask must be broadcastable
+        with the shape of the underlying tensor.
+
+        Args:
+            mask (BoolTensor): the boolean mask.
+            value (float): the value to fill in with.
+
+        Returns:
+            A :class:`ShardedTensor` object whose shards have had masked_fill applied.
+        """
+        if self.dim() < mask.dim():
+            raise ValueError(
+                "mask dim must not greater than the dim of the sharded tensor."
+            )
+        for idx in range(-1, -mask.dim() - 1, -1):
+            if mask.size(idx) != self.size(idx) and mask.size(idx) != 1:
+                raise ValueError(
+                    f"The size of mask {mask.dim() + idx} must match the size of "
+                    f"sharded tensor {self.dim() + idx} at non-singleton dimension {mask.dim() + idx}"
+                )
+        current_rank = dist.get_rank(self._process_group)  # type: ignore[attr-defined]
+        sharding_dim = self.sharding_spec().dim  # type: ignore[attr-defined]
+        narrow_idx = None
+        for idx in range(-1, -mask.dim() - 1, -1):
+            if self.dim() + idx == sharding_dim and mask.size(idx) != 1:
+                narrow_idx = idx
+        if narrow_idx is not None:
+            rank_idx = None
+            for idx, placement in enumerate(self.sharding_spec().placements):  # type: ignore[attr-defined]
+                if placement.rank() == current_rank:  # type: ignore[index]
+                    rank_idx = idx  # type: ignore[attr-defined]
+            shard_metadata = self.metadata().shards_metadata[rank_idx]  # type: ignore[index]
+            mask = mask.narrow(
+                narrow_idx,
+                shard_metadata.shard_offsets[sharding_dim],
+                shard_metadata.shard_sizes[sharding_dim],
+            )
+        local_tensor = self.local_tensor().masked_fill(mask, value)
+        return ShardedTensor._init_from_local_tensor(
+            local_tensor,
+            self.sharding_spec(),
+            self.size(),  # type: ignore[arg-type]
+            process_group=self._process_group,
+        )
+
+    def type_as(self, tensor) -> ShardedTensor:
+        """
+        Returns a new sharded tensor in which each shard has been
+        cast to the type of the given tensor.
+
+        Args:
+            tensor (Tensor): the tensor which has the desired type.
+
+        Returns:
+            A :class:`ShardedTensor` object whose shards have had type_as applied.
+        """
+        if isinstance(tensor, ShardedTensor):
+            tensor = tensor.local_tensor()
+        if self.dtype == tensor.dtype:
+            return self
+        local_shards = []
+        for shard in self.local_shards():
+            local_shards.append(
+                Shard(shard.tensor.type_as(tensor), shard.metadata)
+            )
+        st_meta = copy.deepcopy(self._metadata)
+        st_meta.tensor_properties.dtype = tensor.dtype
+        return ShardedTensor._init_from_local_shards_and_global_metadata(
+            local_shards,
+            st_meta,
+            process_group=self._process_group,
+            init_rrefs=self._init_rrefs,
+        )
+
+    def view(self, *shape) -> ShardedTensor:
+        """
+        Returns a new sharded tensor with the same data as the
+        self tensor but of a different shape for its local tensor.
+
+        For now, we only support passing the view op through to the local
+        tensor.
+
+        Args:
+            shape (torch.Size or int...): the desired size.
+
+        Returns:
+            A :class:`ShardedTensor` object whose shards have had view applied
+            to their local tensors.
+        """
+        if len(shape) == 0:
+            raise ValueError("Missing *shape for sharded view op.")
+        if len(shape) <= self.sharding_spec().dim:  # type: ignore[attr-defined]
+            raise NotImplementedError(
+                f"Shape having dim {len(shape)} is not supported "  # type: ignore[attr-defined]
+                f"for sharded tensor sharded on dim {self.sharding_spec().dim}."
+            )
+        st_size = math.prod(self.size())  # type: ignore[attr-defined]
+        shape_size = math.prod(shape)  # type: ignore[attr-defined]
+        neg_sum = sum(i for i in shape if i < 0)
+        if shape_size > st_size or st_size % shape_size:
+            raise ValueError(
+                f"Shape '{list(shape)}' is invalid for sharded tensor size {st_size}."
+            )
+        if neg_sum < -1:
+            raise ValueError("Only one dimension can be inferred for sharded view op.")
+        try:
+            infer_idx = shape.index(-1)
+        except ValueError:
+            infer_idx = None  # type: ignore[assignment]
+
+        # Infer the dim which is specified with -1.
+        if infer_idx is not None:
+            st_size = math.prod(self.size())  # type: ignore[attr-defined]
+            shape_size = -1 * math.prod(shape)  # type: ignore[attr-defined]
+            shape = (*shape[:infer_idx], st_size // shape_size, *shape[infer_idx + 1 :])
+        if self.size() == shape:
+            return self
+
+        sharding_dim = self.sharding_spec().dim  # type: ignore[attr-defined]
+        world_size = dist.get_world_size(self._process_group)
+        if shape[sharding_dim] % world_size:
+            raise NotImplementedError(
+                f"Case when dim '({shape[sharding_dim]})' is not divisible "
+                "by world_size is not supported."
+            )
+        new_local_tensor_size = (
+            *shape[:sharding_dim],
+            shape[sharding_dim] // world_size,
+            *shape[sharding_dim + 1 :],
+        )
+        return ShardedTensor._init_from_local_tensor(
+            self.local_tensor().view(*new_local_tensor_size).contiguous(),
+            self.sharding_spec(),
+            shape,
+            process_group=self._process_group,
+        )
+
+    def transpose(self, dim0, dim1) -> ShardedTensor:
+        """
+        Returns a new sharded tensor with the given dimensions transposed.
+        The sharding dim moves with the transposed dimension; e.g., if the
+        tensor is sharded by dim 0 and we call transpose(1, 0), the returned
+        tensor will be sharded by dim 1.
+
+        Args:
+            dim0 (int): the first dimension to be transposed.
+            dim1 (int): the second dimension to be transposed.
+
+        Returns:
+            A :class:`ShardedTensor` object whose dims have been transposed
+            as specified in the input.
+        """
+        def _swap_meta_data(data, idx0, idx1):
+            """
+            Swap the items at idx0 and idx1 in the data list.
+            """
+            data[idx0], data[idx1] = data[idx1], data[idx0]
+
+        if not isinstance(self.sharding_spec(), shard_spec.ChunkShardingSpec):
+            raise NotImplementedError(
+                "Only ChunkShardingSpec supported for 'transpose'."
+            )
+        if dim0 == dim1:
+            return self
+        sharding_spec = copy.deepcopy(self.sharding_spec())
+        if sharding_spec.dim == dim0:  # type: ignore[attr-defined]
+            sharding_spec.dim = dim1  # type: ignore[attr-defined]
+        elif sharding_spec.dim == dim1:  # type: ignore[attr-defined]
+            sharding_spec.dim = dim0  # type: ignore[attr-defined]
+
+        st_size = list(self.size())  # type: ignore[arg-type]
+        _swap_meta_data(st_size, dim0, dim1)
+        return ShardedTensor._init_from_local_tensor(
+            self.local_tensor().transpose(dim0, dim1).contiguous(),
+            sharding_spec,
+            tuple(st_size),
+            process_group=self._process_group,
+        )
+
     @property
     def shape(self):
         return self._metadata.size
torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py (changed)
@@ -1,7 +1,10 @@
 import builtins
 
 import torch
 from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
+    EnumerableShardingSpec,
+    ShardMetadata,
 )
 from torch.distributed._shard.sharding_spec._internals import (
     get_chunked_dim_size,
@@ -43,6 +46,35 @@ def generate_chunk_sharding_specs_for_test(sharding_dim):
     ]
 
 
+def generate_enumerable_sharding_specs_for_test():
+    return [
+        EnumerableShardingSpec(
+            [
+                ShardMetadata(
+                    shard_offsets=[0, 0],
+                    shard_sizes=[5, 5],
+                    placement="rank:0/cuda:0",
+                ),
+                ShardMetadata(
+                    shard_offsets=[5, 0],
+                    shard_sizes=[5, 5],
+                    placement="rank:1/cuda:1",
+                ),
+                ShardMetadata(
+                    shard_offsets=[0, 5],
+                    shard_sizes=[5, 5],
+                    placement="rank:2/cuda:2",
+                ),
+                ShardMetadata(
+                    shard_offsets=[5, 5],
+                    shard_sizes=[5, 5],
+                    placement="rank:3/cuda:3",
+                ),
+            ]
+        )
+    ]
+
+
 def generate_local_weight_sharding_params_for_test(
     local_weight, sharded_dim, gpu_num, spec, rank
 ):
@@ -87,10 +119,12 @@ def clone_module_parameter(module, param_name):
     tensor = getattr(module, param_name)
     return torch.nn.Parameter(tensor.detach().clone())
 
-def gen_binary_op_func(python_op):
+def gen_binary_op_func(python_op, inplace=False):
     src_lines = ['def f(lhs, rhs):']
     if "torch" in python_op:
         src_lines.append(f'    return {python_op}(lhs, rhs)\n')
+    elif inplace:
+        src_lines.append(f'    lhs {python_op}= rhs\n    return lhs\n')
     else:
         src_lines.append(f'    return lhs {python_op} rhs\n')
 