Revert "[draft export] avoid storing intermediate real tensors in proxies (#154630)"

This reverts commit 5acb8d5080.

Reverted https://github.com/pytorch/pytorch/pull/154630 on behalf of https://github.com/malfet due to This still OOMs, at least occasionally; see 78624679a8/1 ([comment](https://github.com/pytorch/pytorch/pull/154630#issuecomment-2923759745))
PyTorch MergeBot 2025-05-31 00:07:54 +00:00
parent faf973da5e
commit 0fab32290a
3 changed files with 7 additions and 37 deletions

View File

@@ -668,32 +668,6 @@ class TestDraftExport(TestCase):
                 package_path=f.name,
             )
-
-    @unittest.skipIf(not torch.cuda.is_available(), "Requires cuda")
-    def test_cuda_memory_usage(self):
-        # This used to OOM
-        class Foo(torch.nn.Module):
-            def forward(self, x):
-                for _ in range(100):
-                    x = x + 1e-3
-                return x
-
-        # measure base usage
-        device = torch.device("cuda:0")
-        torch.cuda.reset_peak_memory_stats()
-        base_usage = torch.cuda.memory_allocated(device)
-
-        # usage with input tensor allocated
-        x = torch.randn(2**10, 2**10, 2**8).to(device)
-        x_usage = torch.cuda.memory_allocated(device)
-
-        # draft export peak memory usage
-        draft_export(Foo(), (x,), strict=False)
-        peak_mem_usage = torch.cuda.memory_stats(device)["allocated_bytes.all.peak"]
-
-        # right now it's actually exactly 4x;
-        # I guess original tensor, 2 tensors per add op, 1 for clone stored in node.meta["val"]
-        self.assertTrue((peak_mem_usage - base_usage) <= (x_usage - base_usage) * 4.0)


 if __name__ == "__main__":
     run_tests()
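For scale: the input in the removed test is a float32 tensor with 2**10 * 2**10 * 2**8 = 2**28 elements, i.e. 1 GiB, so the asserted bound allows roughly 4 GiB of peak allocation above the baseline. A minimal sketch of the same peak-memory measurement pattern outside the test harness (the helper name and workload below are illustrative, not from the PR):

```python
import torch

def peak_cuda_bytes_above_base(fn, *args, device="cuda:0"):
    # Same pattern as the removed test: reset the peak counter, record a
    # baseline, run the workload, then read the peak from memory_stats().
    torch.cuda.reset_peak_memory_stats(device)
    base = torch.cuda.memory_allocated(device)
    fn(*args)
    return torch.cuda.memory_stats(device)["allocated_bytes.all.peak"] - base

if torch.cuda.is_available():
    x = torch.randn(2**10, 2**10, 2**8, device="cuda:0")  # 2**28 floats ~= 1 GiB
    extra = peak_cuda_bytes_above_base(lambda t: t + 1e-3, x)
    print(f"peak above baseline: {extra / 2**30:.2f} GiB")
```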

View File

@@ -1059,12 +1059,10 @@ def make_fast_binary_impl(
 # disable the python dispatcher to avoid decomposing detach() further
 # (proxy_mode should still decompose detach() though)
-def fast_detach(fake_mode, x, include_real=False):
+def fast_detach(fake_mode, x):
     with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode):
         out = torch.ops.aten.detach.default(x)
-        if include_real:
-            return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor)
-        return FakeTensor(fake_mode, out, x.device)
+        return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor)


 @functools.lru_cache(None)
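Restoring the unconditional `real_tensor=x.real_tensor` means every detached FakeTensor snapshot keeps a strong reference to the real tensor it was faked from, not only the placeholder snapshots. A toy sketch in plain Python (not the real FakeTensor API) of why such references matter for memory:

```python
import torch

class ToySnapshot:
    # Stand-in for a FakeTensor stored in node.meta["val"]: metadata plus an
    # optional strong reference to the real tensor it was derived from.
    def __init__(self, meta, real_tensor=None):
        self.meta = meta
        self.real_tensor = real_tensor

real = torch.randn(1024, 1024)   # stands in for a large CUDA intermediate
meta = real.to("meta")           # shape/dtype only, no data

# One snapshot per traced op, each keeping the real tensor alive (the
# behavior this revert restores for intermediates).
snapshots = [ToySnapshot(meta, real_tensor=real) for _ in range(100)]
del real
# The storage cannot be freed while any snapshot still references it.
assert snapshots[0].real_tensor is not None
```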

View File

@@ -367,12 +367,12 @@ def get_proxy_slot(
     return res


-def snapshot_fake(val: Tensor, include_real: bool = False) -> Optional[Tensor]:
+def snapshot_fake(val: Tensor) -> Optional[Tensor]:
     # val.detach() will also eventually call fast_detach(),
     # but this saves us a full trip into __torch_dispatch__
     # (snapshot_fake is called a lot)
     if isinstance(val, FakeTensor):
-        return fast_detach(val.fake_mode, val, include_real)
+        return fast_detach(val.fake_mode, val)
     else:
         return val.detach()
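The comment in the hunk above refers to the Python dispatch round trip that a plain `detach()` incurs and that `fast_detach` skips. A minimal sketch of that round trip using a `TorchDispatchMode` (illustrative, not the FakeTensor machinery itself):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogDispatch(TorchDispatchMode):
    # Logs every ATen op that reaches Python dispatch while the mode is active.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("dispatched:", func)
        return func(*args, **(kwargs or {}))

x = torch.randn(4)
with LogDispatch():
    x.detach()  # prints aten.detach.default -- the trip snapshot_fake avoids
```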
@@ -393,9 +393,9 @@ _ExtractValType = Optional[
 ]


-def extract_val(val: _ExtractValType, include_real: bool = False) -> _ExtractValType:
+def extract_val(val: _ExtractValType) -> _ExtractValType:
     if is_fake(val):
-        return snapshot_fake(val, include_real=include_real)
+        return snapshot_fake(val)
     elif isinstance(val, py_sym_types):
         return val
     elif isinstance(val, _AnyScriptObject):
@@ -494,9 +494,7 @@ def maybe_enable_thunkify() -> Generator[None, None, None]:
 # grad_fn, _base (_base actually may be set due to recursive call to
 # ADInplaceOrView, but you shouldn't rely on it.)
 def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy:
-    proxy.node.meta["val"] = extract_val(
-        val, include_real=(proxy.node.op == "placeholder")
-    )
+    proxy.node.meta["val"] = extract_val(val)

     with _enable_thunkify(proxy.tracer):  # type: ignore[arg-type]
         # Best effort tensor_meta setting; prefer using val!
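For orientation, `set_meta` is what populates `node.meta["val"]` on the traced graph, and `extract_val`/`snapshot_fake` produce the detached FakeTensor it stores. A small sketch of where that metadata ends up (the traced function here is just an example):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x + 1e-3

gm = make_fx(f, tracing_mode="fake")(torch.randn(4))
for node in gm.graph.nodes:
    # Each node's "val" entry is the FakeTensor snapshot stored by set_meta.
    print(node.op, node.name, node.meta.get("val"))
```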