diff --git a/test/distributed/_shard/test_replicated_tensor.py b/test/distributed/_shard/test_replicated_tensor.py
new file mode 100644
index 00000000000..474fbfb90aa
--- /dev/null
+++ b/test/distributed/_shard/test_replicated_tensor.py
@@ -0,0 +1,76 @@
+# Owner(s): ["oncall: distributed"]
+
+import torch
+
+import torch.distributed as dist
+
+from torch.testing._internal.common_distributed import (
+    requires_nccl,
+    skip_if_lt_x_gpu,
+)
+
+from torch.testing._internal.distributed._shard.sharded_tensor import (
+    ShardedTensorTestBase,
+    with_comms,
+)
+from torch.distributed._shard.replicated_tensor import ReplicatedTensor
+
+
+class TestReplicatedTensor(ShardedTensorTestBase):
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_replicated_tensor_basics(self):
+        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
+        replica_tensor = ReplicatedTensor(local_tensor)
+        print(replica_tensor.process_group)
+        # validate it's a replicated tensor by checking values on all ranks
+        validated = replica_tensor.validate()
+        self.assertEqual(validated, True)
+        res = replica_tensor + 2
+        self.assertIsInstance(res, torch.Tensor)
+        self.assertNotIsInstance(res, ReplicatedTensor)
+        self.assertEqual(res, torch.ones(3, 3) * 6)
+
+        # modify the local tensor on one rank, and test that validation raises
+        if self.rank == 2:
+            local_tensor += 3
+
+        with self.assertRaisesRegex(ValueError, 'have different values'):
+            replica_tensor.validate()
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_replicated_tensor_inter_op_replicated_tensor(self):
+        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}")
+        replica_tensor1 = ReplicatedTensor(local_tensor * 4)
+        replica_tensor2 = ReplicatedTensor(local_tensor * 6)
+
+        new_tensor = replica_tensor1 * replica_tensor2
+        self.assertIsInstance(new_tensor, ReplicatedTensor)
+        self.assertEqual(new_tensor, torch.ones(3, 3) * 24)
+
+        # test replicated tensor inter-op with different process groups
+        new_pg = dist.new_group(ranks=[1, 2, 3])
+        replica_tensor_new_group = ReplicatedTensor(local_tensor * 3, process_group=new_pg)
+
+        with self.assertRaisesRegex(RuntimeError, 'must be in the same'):
+            replica_tensor_new_group * replica_tensor1
+
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_replicated_tensor_inter_op_tensor(self):
+        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
+        replica_tensor = ReplicatedTensor(local_tensor)
+
+        local_rand_tensor = torch.randn(3, 3, device=f"cuda:{self.rank}")
+
+        new_tensor = replica_tensor + local_rand_tensor
+        self.assertIsInstance(new_tensor, torch.Tensor)
+        self.assertNotIsInstance(new_tensor, ReplicatedTensor)
+
+        self.assertEqual(new_tensor, local_tensor + local_rand_tensor)
diff --git a/test/run_test.py b/test/run_test.py
index 176b35b64ac..d378f1521a6 100644
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -214,6 +214,7 @@ WINDOWS_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/ops/test_linear",
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
+    "distributed/_shard/test_replicated_tensor",
 ] + FSDP_TEST
 
 ROCM_BLOCKLIST = [
@@ -233,6 +234,7 @@ ROCM_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/ops/test_linear",
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
+    "distributed/_shard/test_replicated_tensor",
     "test_determination",
     "test_jit_legacy",
     "test_type_hints",
diff --git a/torch/distributed/_shard/__init__.py b/torch/distributed/_shard/__init__.py
index b6f0776a36a..194ae2c6bc7 100644
--- a/torch/distributed/_shard/__init__.py
+++ b/torch/distributed/_shard/__init__.py
@@ -1 +1 @@
-from .api import shard_parameter, _shard_tensor
+from .api import shard_parameter, _shard_tensor, _replicate_tensor
diff --git a/torch/distributed/_shard/api.py b/torch/distributed/_shard/api.py
index 0de8a59660e..c5082b0d9a8 100644
--- a/torch/distributed/_shard/api.py
+++ b/torch/distributed/_shard/api.py
@@ -7,6 +7,7 @@ from torch.distributed._shard.sharded_tensor import (
 from .sharding_spec import (
     ShardingSpec,
 )
+from .replicated_tensor import ReplicatedTensor
 
 def _shard_tensor(
     tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
@@ -118,3 +119,20 @@ def shard_parameter(
 
     # Now we can set the attribute appropriately.
     setattr(module, param_name, st)
+
+
+def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor:
+    """
+    Given a :class:`torch.Tensor`, mark it as a ReplicatedTensor where all
+    ranks have the same value.
+
+    Args:
+        tensor (:class:`torch.Tensor`): the tensor to be marked as replicated.
+    Keyword args:
+        process_group (ProcessGroup, optional): The process group to replicate on.
+            If None, the default process group will be used.
+    Returns:
+        A :class:`ReplicatedTensor` from the given tensor.
+
+    """
+    return ReplicatedTensor(tensor, process_group=process_group)
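A usage sketch of the `_replicate_tensor` helper added above (illustrative only; it assumes `torch.distributed.init_process_group` has already been called on every rank so that a default process group exists, and the tensor values are arbitrary):

    >>> import torch
    >>> from torch.distributed._shard import _replicate_tensor
    >>> # the same value must be present on every rank for the replica to be valid
    >>> local_tensor = torch.ones(3, 3) * 4
    >>> replica = _replicate_tensor(local_tensor)   # uses the default process group
    >>> # mixing with a plain tensor or scalar falls back to a regular torch.Tensor
    >>> res = replica + 2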
diff --git a/torch/distributed/_shard/replicated_tensor.py b/torch/distributed/_shard/replicated_tensor.py
new file mode 100644
index 00000000000..12253a0b465
--- /dev/null
+++ b/torch/distributed/_shard/replicated_tensor.py
@@ -0,0 +1,125 @@
+import torch
+import torch.distributed as dist
+
+from torch.overrides import get_default_nowrap_functions
+from torch.distributed._shard.sharded_tensor.api import ShardedTensor
+from torch.distributed import distributed_c10d
+
+
+class ReplicatedTensor(torch.Tensor):
+    """
+    ReplicatedTensor represents a tensor which is replicated across the `world_size` and
+    has the same value on each rank.
+
+    ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together
+    with ShardedTensor/Tensor to express different types of computation. The
+    inter-op rules are defined as follows (using torch.add as an example op):
+        ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
+        ReplicatedTensor + torch.Tensor = torch.Tensor
+        ReplicatedTensor + ShardedTensor = ShardedTensor
+        ReplicatedTensor + other type (i.e. Scalar) = other type
+
+    NOTE: We do not guarantee equal content of ReplicatedTensor across nodes after its
+    construction. Although we define proper inter-op rules to make sure ReplicatedTensor
+    stays the same, there is no enforcement of it (i.e. if you manually modify content on
+    some ranks, the modified values will not automatically be synced to other nodes). If
+    you wish to manually validate that tensors are the same across ranks, use `validate()`.
+
+    """
+    process_group: distributed_c10d.ProcessGroup
+
+    __slots__ = ["process_group"]
+
+    def __new__(cls, data=None, process_group=None):
+        if data is None:
+            data = torch.empty(0)
+        r = torch.Tensor._make_subclass(cls, data)  # type: ignore[arg-type]
+        r.process_group = (  # type: ignore[attr-defined]
+            process_group
+            if process_group is not None
+            else distributed_c10d._get_default_group()
+        )
+        return r
+
+    def __repr__(self):
+        return f"ReplicatedTensor({super(ReplicatedTensor, self).__repr__()})"
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        # We will re-dispatch the execution to ShardedTensor's __torch_function__
+        # if we find any ShardedTensor operands. We also check whether all args/kwargs
+        # are ReplicatedTensor operands; we have to do this to ensure we do not
+        # convert results back to ReplicatedTensor unless all operands are replicated.
+        all_replicated = True
+        replicated_pg = None
+
+        def dispatch_arg(arg):
+            nonlocal replicated_pg, all_replicated
+            if isinstance(arg, ShardedTensor):
+                # redispatch to ShardedTensor
+                # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
+                return arg.__torch_function__(func, types, args, kwargs)
+            if isinstance(arg, ReplicatedTensor):
+                if replicated_pg is None:
+                    replicated_pg = arg.process_group
+                elif replicated_pg != arg.process_group:
+                    raise RuntimeError(
+                        f"ReplicatedTensor operands must be in the same process group "
+                        f"in torch function '{func.__name__}', but found at least two "
+                        f"ReplicatedTensor operands in different process groups!")
+            else:
+                all_replicated = False
+
+        for arg in args:
+            dispatch_arg(arg)
+
+        if kwargs is not None:
+            for k, v in kwargs.items():
+                dispatch_arg(v)
+
+        # We can't do super().__torch_function__() as it implicitly converts the result
+        # back to tensor subclasses, whereas in our case we need to control the output type
+        # based on the inter-op rules we defined.
+        with torch._C.DisableTorchFunction():
+            rs = func(*args, **kwargs)
+            if func in get_default_nowrap_functions():
+                return rs
+            if all_replicated and isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
+                # if all operands are ReplicatedTensors, the call did not get dispatched to
+                # ShardedTensor's __torch_function__, and the result is a torch.Tensor, then
+                # we convert and return a ReplicatedTensor according to our inter-op rules
+                rs = rs.as_subclass(cls)  # type: ignore[arg-type]
+                # propagate the process_group field to the result
+                rs.process_group = replicated_pg  # type: ignore[attr-defined]
+
+            return rs
+
+    def validate(self) -> bool:
+        """
+        Validate that this ReplicatedTensor is legit by all-gathering tensors from all
+        ranks and checking that they are the same.
+
+        If any ranks hold different values, a ValueError is raised.
+
+        The all-gather runs on the ``process_group`` this ReplicatedTensor was
+        constructed with; if none was given at construction, the default process
+        group is used.
+
+        Returns:
+            True if validation succeeds.
+        """
+        world_size = dist.get_world_size(self.process_group)
+        current_rank = dist.get_rank(self.process_group)
+
+        tensors_on_rank = [torch.empty_like(self) for _ in range(world_size)]
+
+        dist.all_gather(tensors_on_rank, self, group=self.process_group)
+        # validate and check if all tensors are equal
+        for rank, tensor in enumerate(tensors_on_rank):
+            if not torch.allclose(self, tensor):
+                raise ValueError(
+                    f"ReplicatedTensors have different values on rank {current_rank} and {rank}")
+
+        return True
diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py
index ba1cdf326c2..58b5e022747 100644
--- a/torch/distributed/_shard/sharded_tensor/__init__.py
+++ b/torch/distributed/_shard/sharded_tensor/__init__.py
@@ -366,7 +366,7 @@ def sharded_op_impl(func):
     parameters, the function provided will be invoked for that operator.
 
     Example::
-        >>> @custom_sharded_op(torch.nn.functional.linear)
+        >>> @sharded_op_impl(torch.nn.functional.linear)
        >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
        >>>   ....
        >>>
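Putting the pieces together, a sketch of the inter-op rules and `validate()` from replicated_tensor.py, mirroring the new tests (it assumes an initialized default process group and identical values on every rank; not an exhaustive specification):

    >>> import torch
    >>> from torch.distributed._shard.replicated_tensor import ReplicatedTensor
    >>> rep1 = ReplicatedTensor(torch.ones(3, 3) * 4)
    >>> rep2 = ReplicatedTensor(torch.ones(3, 3) * 6)
    >>> out = rep1 * rep2                  # ReplicatedTensor op ReplicatedTensor -> ReplicatedTensor
    >>> isinstance(out, ReplicatedTensor)
    True
    >>> mixed = rep1 + torch.randn(3, 3)   # mixing in a plain tensor -> plain torch.Tensor
    >>> isinstance(mixed, ReplicatedTensor)
    False
    >>> rep1.validate()                    # all-gathers and compares values across ranks
    True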