diff --git a/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py b/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py
index 58aa774cd05..977fa701b44 100644
--- a/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py
+++ b/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py
@@ -62,6 +62,14 @@ class TestTensorOps(ShardedTensorTestBase):
         st.copy_(ones_st)
         self.assertTrue(torch.equal(st, ones_st))
 
+        # inplace copy under no_grad should work between two ShardedTensors with different requires_grad
+        st_with_grad = sharded_tensor.rand(spec, (12, 5), requires_grad=True)
+        self.assertTrue(st_with_grad.requires_grad)
+        self.assertFalse(ones_st.requires_grad)
+        with torch.no_grad():
+            st_with_grad.copy_(ones_st)
+        self.assertEqual(st_with_grad.local_tensor(), ones_st.local_tensor())
+
     @with_comms(init_rpc=False)
     @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
index 2c9d0df4d84..4d0600e92f8 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
@@ -126,12 +126,11 @@ def sharded_inplace_copy(types, args, kwargs, pg):
     self_st = args[0]
     new_st = args[1]
     nonblocking = kwargs.get("non_blocking", False)
-    self_meta = self_st.metadata()
-    new_meta = new_st.metadata()
-    if self_meta != new_meta:
-        raise RuntimeError(
-            "inplace copy can only happen between two ShardedTensor with same metadata!"
-        )
+    for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()):
+        if local_shard.metadata != new_shard.metadata:
+            raise RuntimeError(
+                "inplace copy can only happen between two ShardedTensor with same metadata!"
+            )
 
     for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()):
         local_shard.tensor.copy_(new_shard.tensor, nonblocking)
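A minimal plain-tensor sketch of the behavior the new test exercises (illustrative only, not part of the patch): under torch.no_grad(), an in-place copy_ between two tensors with different requires_grad is allowed and leaves the destination's requires_grad flag intact, which is the semantics the per-shard metadata check above now permits for ShardedTensor.

    import torch

    # Destination is a leaf tensor with autograd enabled; source is not.
    dst = torch.rand(12, 5, requires_grad=True)
    src = torch.ones(12, 5)

    # Outside no_grad this copy_ would raise, since in-place ops on a leaf
    # that requires grad are forbidden; inside no_grad it succeeds and
    # records nothing in the autograd graph.
    with torch.no_grad():
        dst.copy_(src)

    assert torch.equal(dst, src)
    assert dst.requires_grad  # the flag survives the copy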