mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[_shard] only check shard metadata for copy_ (#82655)
copy_ does not restrict on tensor properites, it does not check things like requires_grad or dtype, so only check if the shard metadata are the same Differential Revision: [D38359176](https://our.internmc.facebook.com/intern/diff/D38359176/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/82655 Approved by: https://github.com/fduwjj
This commit is contained in:
parent
0a919e8bd6
commit
cda8635a5e
|
|
@ -62,6 +62,14 @@ class TestTensorOps(ShardedTensorTestBase):
|
|||
st.copy_(ones_st)
|
||||
self.assertTrue(torch.equal(st, ones_st))
|
||||
|
||||
# no grad inplace_copy should work between two with different requires_grad
|
||||
st_with_grad = sharded_tensor.rand(spec, (12, 5), requires_grad=True)
|
||||
self.assertTrue(st_with_grad.requires_grad)
|
||||
self.assertFalse(ones_st.requires_grad)
|
||||
with torch.no_grad():
|
||||
st_with_grad.copy_(ones_st)
|
||||
self.assertEqual(st_with_grad.local_tensor(), ones_st.local_tensor())
|
||||
|
||||
@with_comms(init_rpc=False)
|
||||
@skip_if_lt_x_gpu(TEST_GPU_NUM)
|
||||
@requires_nccl()
|
||||
|
|
|
|||
|
|
@ -126,12 +126,11 @@ def sharded_inplace_copy(types, args, kwargs, pg):
|
|||
self_st = args[0]
|
||||
new_st = args[1]
|
||||
nonblocking = kwargs.get("non_blocking", False)
|
||||
self_meta = self_st.metadata()
|
||||
new_meta = new_st.metadata()
|
||||
if self_meta != new_meta:
|
||||
raise RuntimeError(
|
||||
"inplace copy can only happen between two ShardedTensor with same metadata!"
|
||||
)
|
||||
for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()):
|
||||
if local_shard.metadata != new_shard.metadata:
|
||||
raise RuntimeError(
|
||||
"inplace copy can only happen between two ShardedTensor with same metadata!"
|
||||
)
|
||||
for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()):
|
||||
local_shard.tensor.copy_(new_shard.tensor, nonblocking)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user