[_shard] only check shard metadata for copy_ (#82655)

copy_ does not restrict on tensor properites, it does not check things like requires_grad or dtype, so only check if the shard metadata are the same

Differential Revision: [D38359176](https://our.internmc.facebook.com/intern/diff/D38359176/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82655
Approved by: https://github.com/fduwjj
This commit is contained in:
Wanchao Liang 2022-08-02 16:44:58 -07:00 committed by PyTorch MergeBot
parent 0a919e8bd6
commit cda8635a5e
2 changed files with 13 additions and 6 deletions

View File

@ -62,6 +62,14 @@ class TestTensorOps(ShardedTensorTestBase):
st.copy_(ones_st)
self.assertTrue(torch.equal(st, ones_st))
# no grad inplace_copy should work between two with different requires_grad
st_with_grad = sharded_tensor.rand(spec, (12, 5), requires_grad=True)
self.assertTrue(st_with_grad.requires_grad)
self.assertFalse(ones_st.requires_grad)
with torch.no_grad():
st_with_grad.copy_(ones_st)
self.assertEqual(st_with_grad.local_tensor(), ones_st.local_tensor())
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()

View File

@ -126,9 +126,8 @@ def sharded_inplace_copy(types, args, kwargs, pg):
self_st = args[0]
new_st = args[1]
nonblocking = kwargs.get("non_blocking", False)
self_meta = self_st.metadata()
new_meta = new_st.metadata()
if self_meta != new_meta:
for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()):
if local_shard.metadata != new_shard.metadata:
raise RuntimeError(
"inplace copy can only happen between two ShardedTensor with same metadata!"
)