mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert "faketensor: prevent deepcopy from cloning FakeTensorMode (#104476)"
This reverts commit c54afea6ee.
Reverted https://github.com/pytorch/pytorch/pull/104476 on behalf of https://github.com/jeanschmidt due to sadly it is breaking internal tests, and I can't coordinate a FF due to timezone differences ([comment](https://github.com/pytorch/pytorch/pull/104476#issuecomment-1661808343))
This commit is contained in:
parent
d528a137e0
commit
fdd4b3aaa8
|
|
@ -1014,15 +1014,6 @@ class FakeTensorOperatorInvariants(TestCase):
|
|||
|
||||
self.assertEqual(ref.size(), meta_out.size())
|
||||
|
||||
def test_module_deepcopy(self):
|
||||
import copy
|
||||
from torch._guards import detect_fake_mode
|
||||
with FakeTensorMode() as m:
|
||||
lin1 = torch.nn.Linear(2, 2)
|
||||
lin2 = copy.deepcopy(lin1)
|
||||
all_params = list(lin1.parameters()) + list(lin2.parameters())
|
||||
curr_mode = detect_fake_mode(all_params)
|
||||
self.assertTrue(curr_mode is m)
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
|
|
|
|||
|
|
@ -1028,16 +1028,6 @@ class FakeTensor(torch.Tensor):
|
|||
def from_tensor(t, fake_mode):
|
||||
return fake_mode.from_tensor(t)
|
||||
|
||||
# FakeTensorMode is meant to be a singleton, so deepcopying
|
||||
# should not introduce a fresh mode.
|
||||
# This just implements the "default" deepcopy, but without deepcopying
|
||||
# the fake_mode.
|
||||
def __deepcopy__(self, memo):
|
||||
result = torch.Tensor.__deepcopy__(self, memo)
|
||||
assert isinstance(result, FakeTensor)
|
||||
result.fake_mode = self.fake_mode
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@count
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user